Example #1
def test_parameter_update():
    original_param_values = {}
    original_bn_stat_values = {}

    model = get_model_for_test()

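    # Snapshot the original weights: for BN layers, the running statistics (non-trainable
    # weights) and the trainable parameters separately; for all other layers, every weight.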
    for layer in model.layers:
        if get_keras_layer_metatype(
                layer) == TFBatchNormalizationLayerMetatype:
            original_bn_stat_values[layer] = deepcopy(
                layer.non_trainable_weights)
            original_param_values[layer] = deepcopy(layer.trainable_weights)
        else:
            original_param_values[layer] = deepcopy(layer.weights)

    config = get_config_for_test()

    bn_adaptation = BatchnormAdaptationAlgorithm(
        **extract_bn_adaptation_init_params(config, "quantization"))
    bn_adaptation.run(model)

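    # After adaptation the BN running statistics must have changed, while all trainable
    # parameters stay intact.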
    for layer in model.layers:
        if get_keras_layer_metatype(
                layer) == TFBatchNormalizationLayerMetatype:
            compare_params(original_bn_stat_values[layer],
                           layer.non_trainable_weights,
                           equal=False)
            compare_params(original_param_values[layer],
                           layer.trainable_weights)
        else:
            compare_params(original_param_values[layer], layer.weights)
Example #2
def test_all_parameter_keep():
    original_all_param_values = {}

    model = get_model_for_test()

    for layer in model.layers:
        original_all_param_values[layer] = deepcopy(layer.weights)

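    # With num_bn_adaptation_samples=0 the adaptation is effectively disabled, so every
    # weight must remain unchanged.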
    config = get_config_for_test(num_bn_adaptation_samples=0)

    bn_adaptation = BatchnormAdaptationAlgorithm(
        **extract_bn_adaptation_init_params(config, "quantization"))
    bn_adaptation.run(model)

    for layer in model.layers:
        compare_params(original_all_param_values[layer], layer.weights)
Example #3
def _run_batchnorm_adaptation(self):
    # Lazily construct the BN adaptation algorithm on first use, reading its parameters
    # from the 'magnitude_sparsity' section of the config, then run it on the current model.
    if self._bn_adaptation is None:
        self._bn_adaptation = BatchnormAdaptationAlgorithm(
            **extract_bn_adaptation_init_params(self._config,
                                                'magnitude_sparsity'))
    self._bn_adaptation.run(self.model)
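
The helper above only constructs the algorithm lazily; the sampling budget itself comes from the
config passed to extract_bn_adaptation_init_params. Below is a minimal sketch of what such a
config section might look like; the key names are an assumption based on typical NNCF configs
and are not taken from the examples on this page.

# Hypothetical config fragment read by extract_bn_adaptation_init_params(); the schema is an
# assumption. A data source for the adaptation run still has to be registered separately
# (e.g. via register_default_init_args).
config_dict = {
    "input_info": {"sample_size": [1, 224, 224, 3]},
    "compression": {
        "algorithm": "magnitude_sparsity",
        "initializer": {
            "batchnorm_adaptation": {"num_bn_adaptation_samples": 2000}
        },
    },
}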
Example #4
class MagnitudeSparsityController(BaseSparsityController):
    """
    Serves as a handle to the additional modules, parameters and hooks inserted
    into the original uncompressed model in order to enable algorithm-specific compression.
    Hosts entities that are to be used during the training process, such as compression scheduler and
    compression loss.
    """
    def __init__(self, target_model, config: NNCFConfig, op_names):
        super().__init__(target_model, op_names)
        algo_config = extract_algo_specific_config(config,
                                                   'magnitude_sparsity')
        params = deepcopy(algo_config.get('params', {}))
        self._threshold = 0
        self._frozen = False
        self._weight_importance_fn = WEIGHT_IMPORTANCE_FUNCTIONS[params.get(
            'weight_importance', 'normed_abs')]

        sparsity_init = algo_config.get('sparsity_init', 0)
        params['sparsity_init'] = sparsity_init
        scheduler_type = params.get('schedule', 'polynomial')

        if scheduler_type == 'adaptive':
            raise ValueError(
                'The magnitude sparsity algorithm does not support the adaptive scheduler'
            )

        scheduler_cls = SPARSITY_SCHEDULERS.get(scheduler_type)
        self._scheduler = scheduler_cls(self, params)
        self._loss = TFZeroCompressionLoss()
        self._bn_adaptation = None
        self._config = config
        self.set_sparsity_level(sparsity_init)

    @property
    def scheduler(self) -> CompressionScheduler:
        return self._scheduler

    @property
    def loss(self) -> CompressionLoss:
        return self._loss

    def freeze(self, freeze: bool = True):
        self._frozen = freeze

    def set_sparsity_level(self,
                           sparsity_level,
                           run_batchnorm_adaptation: bool = False):
        if not self._frozen:
            if sparsity_level >= 1 or sparsity_level < 0:
                raise AttributeError(
                    'Sparsity level should be within interval [0,1), actual value to set is: {}'
                    .format(sparsity_level))

            self._threshold = self._select_threshold(sparsity_level)
            self._set_masks_for_threshold(self._threshold)

        if run_batchnorm_adaptation:
            self._run_batchnorm_adaptation()

    def _select_threshold(self, sparsity_level):
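        # Concatenate and sort the importance scores of all sparsifiable weights; the threshold
        # is the value at the index determined by the requested sparsity level.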
        all_weights = self._collect_all_weights()
        if not all_weights:
            return 0.0
        all_weights_tensor = tf.sort(tf.concat(all_weights, 0))
        index = int(
            tf.cast(tf.size(all_weights_tensor), all_weights_tensor.dtype) *
            sparsity_level)
        threshold = all_weights_tensor[index].numpy()
        return threshold

    def _set_masks_for_threshold(self, threshold_val):
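        # Recompute the binary mask of every sparsified weight: positions whose importance
        # falls below the threshold are masked out.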
        for wrapped_layer in collect_wrapped_layers(self._model):
            for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
                weight = wrapped_layer.layer_weights[weight_attr]

                for op_name in ops:
                    if op_name in self._op_names:
                        wrapped_layer.ops_weights[op_name]['mask'].assign(
                            calc_magnitude_binary_mask(
                                weight, self._weight_importance_fn,
                                threshold_val))

    def _collect_all_weights(self):
        all_weights = []
        for wrapped_layer in collect_wrapped_layers(self._model):
            for weight_attr, ops in wrapped_layer.weights_attr_ops.items():
                for op_name in ops:
                    if op_name in self._op_names:
                        all_weights.append(
                            tf.reshape(
                                self._weight_importance_fn(
                                    wrapped_layer.layer_weights[weight_attr]),
                                [-1]))
        return all_weights

    @property
    def compression_rate(self) -> float:
        return self.statistics(
        ).magnitude_sparsity.model_statistics.sparsity_level

    @compression_rate.setter
    def compression_rate(self, compression_rate: float) -> None:
        self.freeze(False)
        self.set_sparsity_level(compression_rate)
        self.freeze(True)

    def disable_scheduler(self):
        self._scheduler = StubCompressionScheduler()
        self._scheduler.current_sparsity_level = 0.0

    def statistics(self,
                   quickly_collected_only: bool = False) -> NNCFStatistics:
        collector = TFSparseModelStatisticsCollector(self.model,
                                                     self._op_names)
        model_stats = collector.collect()

        threshold_stats = []
        threshold = self._select_threshold(model_stats.sparsity_level)
        for s in model_stats.sparsified_layers_summary:
            threshold_stats.append(LayerThreshold(s.name, threshold))

        target_sparsity_level = self.scheduler.current_sparsity_level

        stats = MagnitudeSparsityStatistics(model_stats, threshold_stats,
                                            target_sparsity_level)

        nncf_stats = NNCFStatistics()
        nncf_stats.register('magnitude_sparsity', stats)
        return nncf_stats

    def _run_batchnorm_adaptation(self):
        if self._bn_adaptation is None:
            self._bn_adaptation = BatchnormAdaptationAlgorithm(
                **extract_bn_adaptation_init_params(self._config,
                                                    'magnitude_sparsity'))
        self._bn_adaptation.run(self.model)
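
A hedged sketch of how such a controller is typically driven during training; create_compressed_model,
train_one_epoch, num_epochs and the final sparsity value are placeholders, only scheduler and
set_sparsity_level come from the class above.

# Hypothetical training-loop sketch; create_compressed_model(), train_one_epoch() and
# num_epochs are placeholders, not taken from the class above.
compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)
for epoch in range(num_epochs):
    compression_ctrl.scheduler.epoch_step()  # the scheduler drives the target sparsity level
    train_one_epoch(compressed_model)
# Optionally refresh the BN statistics once the final sparsity level has been set.
compression_ctrl.set_sparsity_level(0.5, run_batchnorm_adaptation=True)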
Example #5
def _run_batchnorm_adaptation(self, model: tf.keras.Model) -> None:
    if self._bn_adaptation is None:
        self._bn_adaptation = BatchnormAdaptationAlgorithm(
            **extract_bn_adaptation_init_params(self.config, self.name))
    self._bn_adaptation.run(model)
Example #6
class QuantizationBuilder(TFCompressionAlgorithmBuilder):
    _state_names = QBuilderStateNames

    DEFAULT_QCONFIG = QuantizerConfig(num_bits=8,
                                      mode=QuantizationMode.SYMMETRIC,
                                      signedness_to_force=None,
                                      per_channel=False)

    def __init__(self, config: NNCFConfig, should_init: bool = True):
        super().__init__(config, should_init)

        self.quantize_inputs = self._algo_config.get('quantize_inputs', True)
        self.quantize_outputs = self._algo_config.get('quantize_outputs',
                                                      False)
        self._overflow_fix = self._algo_config.get('overflow_fix', 'enable')
        self._target_device = config.get('target_device', 'ANY')
        algo_config = self._get_algo_specific_config_section()
        if self._target_device == 'VPU' and 'preset' in algo_config:
            raise RuntimeError(
                "The VPU target device does not support presets.")

        self.global_quantizer_constraints = {}
        self.ignored_scopes_per_group = {}
        self.target_scopes_per_group = {}
        self._op_names = []

        for quantizer_group in QuantizerGroup:
            self._parse_group_params(self._algo_config, quantizer_group)

        if self.should_init:
            self._parse_init_params()

        self._range_initializer = None
        self._bn_adaptation = None
        self._quantizer_setup = None

        self.hw_config = None
        if self._target_device != "TRIAL":
            hw_config_type = HWConfigType.from_str(
                HW_CONFIG_TYPE_TARGET_DEVICE_MAP[self._target_device])
            hw_config_path = TFHWConfig.get_path_to_hw_config(hw_config_type)
            self.hw_config = TFHWConfig.from_json(hw_config_path)

    def _load_state_without_name(self, state_without_name: Dict[str, Any]):
        """
        Initializes object from the state.

        :param state_without_name: Output of `get_state()` method.
        """
        quantizer_setup_state = state_without_name[
            self._state_names.QUANTIZER_SETUP]
        self._quantizer_setup = TFQuantizationSetup.from_state(
            quantizer_setup_state)

    def _get_state_without_name(self) -> Dict[str, Any]:
        """
        Returns a dictionary with Python data structures (dict, list, tuple, str, int, float, True, False, None) that
        represents state of the object.

        :return: state of the object
        """
        quantizer_setup_state = self._quantizer_setup.get_state()
        return {self._state_names.QUANTIZER_SETUP: quantizer_setup_state}

    def _parse_init_params(self):
        self._range_init_params = self._parse_range_init_params()

    def _parse_range_init_params(self) -> TFRangeInitParams:
        range_init_params = extract_range_init_params(self.config)
        return TFRangeInitParams(
            **range_init_params) if range_init_params is not None else None

    def _parse_group_params(self, quant_config: Dict,
                            quantizer_group: QuantizerGroup) -> None:
        group_name = quantizer_group.value
        params_dict = {}
        params_dict_from_config = quant_config.get(group_name, {})
        preset = quant_config.get('preset')
        if self._target_device in [
                'ANY', 'CPU', 'GPU'
        ] or self._target_device == 'TRIAL' and preset is not None:
            preset = QuantizationPreset.from_str(
                quant_config.get('preset', 'performance'))
            params_dict = preset.get_params_configured_by_preset(
                quantizer_group)
            overridden_params = params_dict.keys() & params_dict_from_config.keys()
            if overridden_params:
                logger.warning(
                    'Preset quantizer parameters {} are explicitly overridden.'.format(
                        overridden_params))
        params_dict.update(params_dict_from_config)
        self.global_quantizer_constraints[
            quantizer_group] = QuantizationConstraints.from_config_dict(
                params_dict)
        self.ignored_scopes_per_group[
            quantizer_group] = params_dict_from_config.get(
                'ignored_scopes', [])
        if self.ignored_scopes is not None:
            self.ignored_scopes_per_group[
                quantizer_group] += self.ignored_scopes
        target_scopes = params_dict_from_config.get('target_scopes')
        if target_scopes is None and self.target_scopes is not None:
            self.target_scopes_per_group[quantizer_group] = self.target_scopes
        else:
            self.target_scopes_per_group[quantizer_group] = target_scopes

    def _get_default_qconfig(
            self,
            constraints: QuantizationConstraints = None) -> QuantizerConfig:
        qconfig = deepcopy(self.DEFAULT_QCONFIG)
        if constraints is not None:
            qconfig = constraints.apply_constraints_to(qconfig)
        return qconfig

    def _get_half_range(self, qconfig: QuantizerConfig, target_node: NNCFNode,
                        first_conv_nodes: List[NNCFNode]) -> bool:
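        # The overflow fix restricts 8-bit weight quantizers to half of their range (effectively
        # 7 bits) to avoid the overflow issue on AVX2/AVX-512; see _raise_overflow_fix_warning.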
        if self._target_device in ['CPU', 'ANY'] and qconfig.num_bits == 8:
            if self._overflow_fix == 'enable':
                return True
            if self._overflow_fix == 'first_layer_only':
                if target_node in first_conv_nodes:
                    return True
        return False

    def _create_quantizer(self, name: str,
                          qspec: TFQuantizerSpec) -> Quantizer:
        quantizer_cls = NNCF_QUANTIZATION_OPERATIONS.get(qspec.mode)
        return quantizer_cls(name, qspec)

    def _build_insertion_commands_for_quantizer_setup(self,
                                                      quantizer_setup: TFQuantizationSetup) \
            -> List[TFInsertionCommand]:
        insertion_commands = []
        quantization_points = quantizer_setup.get_quantization_points()
        non_unified_scales_quantization_point_ids = set(
            range(len(quantization_points)))

        for unified_scales_group in quantizer_setup.get_unified_scale_groups():
            us_qp_id = unified_scales_group[0]
            qp = quantization_points[us_qp_id]
            quantizer_spec = qp.quantizer_spec
            op_name = qp.op_name + '/unified_scale_group'
            quantizer = FakeQuantize(quantizer_spec, name=op_name)
            self._op_names.append(quantizer.op_name)
            target_points = []
            for us_qp_id in unified_scales_group:
                non_unified_scales_quantization_point_ids.discard(us_qp_id)
                qp = quantization_points[us_qp_id]
                assert quantizer_spec.get_state(
                ) == qp.quantizer_spec.get_state()
                target_points.append(qp.target_point)

            command = TFInsertionCommand(
                target_point=TFMultiLayerPoint(target_points),
                callable_object=quantizer,
                priority=TransformationPriority.QUANTIZATION_PRIORITY)

            insertion_commands.append(command)

        for qp_id in non_unified_scales_quantization_point_ids:
            quantization_point = quantization_points[qp_id]
            op_name = quantization_point.op_name
            quantizer_spec = quantization_point.quantizer_spec
            target_point = quantization_point.target_point
            if quantization_point.is_weight_quantization():
                quantizer = self._create_quantizer(op_name, quantizer_spec)
                self._op_names.append(op_name)
            else:
                quantizer = FakeQuantize(quantizer_spec, name=op_name)
                self._op_names.append(quantizer.op_name)
            command = TFInsertionCommand(
                target_point=target_point,
                callable_object=quantizer,
                priority=TransformationPriority.QUANTIZATION_PRIORITY)
            insertion_commands.append(command)
        return insertion_commands

    def get_transformation_layout(
            self, model: tf.keras.Model) -> TFTransformationLayout:
        transformations = TFTransformationLayout()
        if self._quantizer_setup is None:
            self._quantizer_setup = self._get_quantizer_setup(model)
        insertion_commands = self._build_insertion_commands_for_quantizer_setup(
            self._quantizer_setup)
        for command in insertion_commands:
            transformations.register(command)
        return transformations

    def _get_custom_layer_node_names(
            self, nncf_graph: NNCFGraph,
            converter: TFModelConverter) -> List[NNCFNodeName]:
        retval = []
        for node in nncf_graph.get_all_nodes():
            metatype = node.metatype
            if metatype in OUTPUT_NOOP_METATYPES:
                continue
            is_custom, _ = converter.get_layer_info_for_node(node.node_name)
            if is_custom:
                retval.append(node.node_name)
        return retval

    def _build_controller(self,
                          model: tf.keras.Model) -> 'QuantizationController':
        return QuantizationController(model, self.config, self._op_names)

    def initialize(self, model: tf.keras.Model) -> None:
        if self._range_init_params is not None:
            self._run_range_initialization(model)
        self._run_batchnorm_adaptation(model)

    def _run_range_initialization(self, model: tf.keras.Model) -> None:
        if self._range_initializer is None:
            self._range_initializer = RangeInitializer(self._range_init_params)
        self._range_initializer.run(model)

    def _run_batchnorm_adaptation(self, model: tf.keras.Model) -> None:
        if self._bn_adaptation is None:
            self._bn_adaptation = BatchnormAdaptationAlgorithm(
                **extract_bn_adaptation_init_params(self.config, self.name))
        self._bn_adaptation.run(model)

    def _get_quantizer_setup(self,
                             model: tf.keras.Model) -> TFQuantizationSetup:
        converter = TFModelConverterFactory.create(model)
        nncf_graph = converter.convert()
        nodes = nncf_graph.get_all_nodes()
        for node in nodes:
            if node.metatype in NOT_SUPPORT_LAYER_METATYPES:
                logger.warning(
                    'The layer {} is not supported by the quantization algorithm'
                    .format(
                        get_original_name_and_instance_idx(node.node_name)[0]))

        quantizable_weighted_layer_nodes = self._get_quantizable_weighted_layer_nodes(
            nncf_graph)
        custom_layer_nodes = self._get_custom_layer_node_names(
            nncf_graph, converter)

        quantizer_setup = self._get_quantizer_propagation_solution(
            nncf_graph, quantizable_weighted_layer_nodes, custom_layer_nodes,
            model)
        setup = TFQuantizationSetup()

        quantized_layer_names_vs_qconfigs = {
        }  # type: Dict[str, QuantizerConfig]
        qp_id_to_index = {}  # type: Dict[QuantizationPointId, int]
        tf_setup_qp_index = 0
        applied_overflow_fix = False
        first_conv_nodes = get_first_nodes_of_type(nncf_graph, ['Conv2D'])
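        # Convert every quantization point selected by the solver into a TF-specific one:
        # weight quantizers are created per weight attribute of the target layer, while
        # activation quantizers become FakeQuantize layers placed before or after the layer.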
        for qp_id, qp in quantizer_setup.quantization_points.items():
            if qp.is_weight_quantization_point():
                target_node = nncf_graph.get_node_by_name(
                    qp.insertion_point.target_node_name)
                is_custom, layer_info = converter.get_layer_info_for_node(
                    target_node.node_name)
                if is_custom:
                    raise RuntimeError(
                        "Quantizing custom layer weights is currently unsupported!"
                    )
                layer_name = layer_info.layer_name
                qconfig = qp.qconfig
                if layer_name in quantized_layer_names_vs_qconfigs:
                    assigned_qconfig = quantized_layer_names_vs_qconfigs[
                        layer_name]
                    if qconfig != assigned_qconfig:
                        raise RuntimeError(
                            f"Inconsistent quantizer configurations selected by solver for one and the "
                            f"same quantizable layer! Tried to assign {qconfig} to {layer_name} as "
                            f"specified by QP {qp_id}, but the layer already has quantizer "
                            f"config {assigned_qconfig} assigned to it!")
                    continue  # The layer has already been quantized
                quantized_layer_names_vs_qconfigs[layer_name] = qconfig
                metatype = target_node.metatype
                assert issubclass(metatype, TFLayerWithWeightsMetatype)
                for weight_def in metatype.weight_definitions:
                    op_name = self._get_quantizer_operation_name(
                        target_node.node_name, weight_def.weight_attr_name)
                    self._op_names.append(op_name)

                    half_range = self._get_half_range(qconfig, target_node,
                                                      first_conv_nodes)
                    applied_overflow_fix = applied_overflow_fix or half_range
                    quantizer_spec = TFQuantizerSpec.from_config(
                        qconfig,
                        narrow_range=not half_range,
                        half_range=half_range)
                    target_point = TFLayerWeight(layer_info.layer_name,
                                                 weight_def.weight_attr_name)
                    qpoint = TFQuantizationPoint(op_name, quantizer_spec,
                                                 target_point)
            else:
                assert qp.is_activation_quantization_point()
                ip = qp.insertion_point
                assert isinstance(ip, ActivationQuantizationInsertionPoint)
                target_node_name = ip.target_node_name
                input_port_id = ip.input_port_id
                fake_quantize_name = self._get_fake_quantize_name(
                    target_node_name, input_port_id)
                quantizer_spec = TFQuantizerSpec.from_config(
                    qp.qconfig, narrow_range=False, half_range=False)
                fake_quantize_layer = FakeQuantize(quantizer_spec,
                                                   name=fake_quantize_name)
                self._op_names.append(fake_quantize_layer.op_name)

                is_custom, layer_info = converter.get_layer_info_for_node(
                    target_node_name)
                if is_custom:
                    raise RuntimeError(
                        "Quantizing custom layer activations is currently unsupported!"
                    )
                if input_port_id is not None:
                    target_point = TFBeforeLayer(
                        layer_info.layer_name,
                        instance_idx=layer_info.instance_idx,
                        input_port_id=input_port_id)
                else:
                    target_point = TFAfterLayer(
                        layer_info.layer_name,
                        instance_idx=layer_info.instance_idx,
                        output_port_id=0)
                qpoint = TFQuantizationPoint(fake_quantize_name,
                                             quantizer_spec, target_point)

            setup.add_quantization_point(qpoint)
            qp_id_to_index[qp_id] = tf_setup_qp_index
            tf_setup_qp_index += 1

        setup = self._generate_unified_scale_groups(model, quantizer_setup,
                                                    qp_id_to_index, setup)

        self._raise_overflow_fix_warning(applied_overflow_fix)

        return setup

    def _raise_overflow_fix_warning(self, applied_overflow_fix: bool):
        if applied_overflow_fix:
            if self._overflow_fix == 'enable':
                quantizers_with_overflow_fix_str = 'all weight quantizers'
            elif self._overflow_fix == 'first_layer_only':
                quantizers_with_overflow_fix_str = 'first convolution weight quantizers'
            logger.warning(
                'The overflow issue fix will be applied. '
                'Now {} will effectively use only 7 bits out of '
                '8 bits. This resolves the overflow issue on AVX2 and AVX-512 machines. '
                'Please take a look at the documentation for detailed information.'
                .format(quantizers_with_overflow_fix_str))

    def _generate_unified_scale_groups(
            self, model: tf.keras.Model,
            quantizer_setup: SingleConfigQuantizerSetup,
            qp_id_to_index: Dict[QuantizationPointId, int],
            setup: TFQuantizationSetup) -> TFQuantizationSetup:
        # To properly set the instance indices for FakeQuantize layers, preserve the layer order as it appears in the model config
        layer_names = [layer.name for layer in model.layers]
        for unified_group in quantizer_setup.unified_scale_groups.values():
            sorted_unified_group = []
            for qp_id in unified_group:
                qp = quantizer_setup.quantization_points[qp_id]
                qp_layer_name = qp.insertion_point.target_node_name
                original_name, _ = get_original_name_and_instance_idx(
                    qp_layer_name)
                layer_idx = layer_names.index(original_name)
                tf_setup_index = qp_id_to_index[qp_id]
                sorted_unified_group.append((tf_setup_index, layer_idx))

            sorted_unified_group = sorted(sorted_unified_group,
                                          key=lambda x: x[1])
            setup.register_unified_scale_group(
                [setup_index for setup_index, _ in sorted_unified_group])
        return setup

    def _get_quantizable_weighted_layer_nodes(
            self, nncf_graph: NNCFGraph) -> List[QuantizableWeightedLayerNode]:
        nodes_with_weights = []
        for node in nncf_graph.get_all_nodes():
            metatype = node.metatype
            if metatype in OUTPUT_NOOP_METATYPES:
                continue

            if not (metatype in QUANTIZATION_LAYER_METATYPES
                    and should_consider_scope(
                        node.node_name,
                        ignored_scopes=self.ignored_scopes_per_group[
                            QuantizerGroup.WEIGHTS],
                        target_scopes=None)):
                continue

            assert issubclass(metatype, TFLayerWithWeightsMetatype)
            nodes_with_weights.append(node)
        scope_overrides_dict = self._get_algo_specific_config_section().get(
            'scope_overrides', {})
        weighted_node_and_qconf_lists = assign_qconfig_lists_to_modules(
            nodes_with_weights,
            self.DEFAULT_QCONFIG,
            self.global_quantizer_constraints[QuantizerGroup.WEIGHTS],
            scope_overrides_dict,
            hw_config=self.hw_config)
        return [
            QuantizableWeightedLayerNode(node, qconf_list)
            for node, qconf_list in weighted_node_and_qconf_lists.items()
        ]

    def _get_quantizer_propagation_solution(self, nncf_graph: NNCFGraph,
                                            quantizable_weighted_layer_nodes: List[QuantizableWeightedLayerNode],
                                            custom_layer_node_names: List[NNCFNodeName],
                                            model: tf.keras.Model) \
            -> SingleConfigQuantizerSetup:
        ip_graph = InsertionPointGraph(
            nncf_graph,
            [qn.node.node_name for qn in quantizable_weighted_layer_nodes])

        pattern = TF_HW_FUSED_PATTERNS.get_full_pattern_graph()
        ip_graph = ip_graph.get_ip_graph_with_merged_hw_optimized_operations(
            pattern)

        input_preprocessing_nodes = self._get_input_preprocessing_nodes(
            nncf_graph, model)
        input_preprocessing_node_names = [
            n.node_name for n in input_preprocessing_nodes
        ]
        if custom_layer_node_names:
            logger.warning(
                'Custom layers [{}] '
                'will be ignored during quantization since they are not yet supported in NNCF'
                .format(", ".join([str(l) for l in custom_layer_node_names])))
        ignored_scopes_for_solver = self.ignored_scopes_per_group[QuantizerGroup.ACTIVATIONS] + \
                                    input_preprocessing_node_names + custom_layer_node_names

        solver = QuantizerPropagationSolver(
            ignored_scopes=ignored_scopes_for_solver,
            target_scopes=self.target_scopes_per_group[
                QuantizerGroup.ACTIVATIONS],
            hw_config=self.hw_config,
            default_trait_to_metatype_map=DEFAULT_TF_QUANT_TRAIT_TO_OP_DICT,
            default_qconfig_list=[
                self._get_default_qconfig(self.global_quantizer_constraints[
                    QuantizerGroup.ACTIVATIONS])
            ],
            quantizable_layer_nodes=quantizable_weighted_layer_nodes,
            global_constraints=self.global_quantizer_constraints,
            quantize_outputs=self.quantize_outputs)

        quantization_proposal = solver.run_on_ip_graph(ip_graph)
        multi_config_setup = quantization_proposal.quantizer_setup
        single_config_setup = multi_config_setup.select_first_qconfig_for_each_point(
        )
        finalized_proposal = quantization_proposal.finalize(
            single_config_setup)
        final_setup = solver.get_final_quantizer_setup(finalized_proposal)
        final_setup = self._handle_quantize_inputs_option(
            final_setup, nncf_graph)

        return final_setup

    def _handle_quantize_inputs_option(
            self, quantizer_setup: SingleConfigQuantizerSetup,
            nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
        qp_ids_to_discard = []
        for qp_id, qp in quantizer_setup.quantization_points.items():
            if qp.is_activation_quantization_point():
                insertion_point = qp.insertion_point
                target_node = nncf_graph.get_node_by_name(
                    insertion_point.target_node_name)
                if not self.quantize_inputs and target_node.metatype in INPUT_NOOP_METATYPES:
                    qp_ids_to_discard.append(qp_id)
        for qp_id in qp_ids_to_discard:
            quantizer_setup.discard(qp_id, keep_shared_input_qps=True)
        return quantizer_setup

    def _get_input_preprocessing_nodes(
            self, nncf_graph: NNCFGraph,
            model: tf.keras.Model) -> List[NNCFNode]:
        retval = []

        def traverse_fn(
            node: NNCFNode, preprocessing_nodes: List[NNCFNode]
        ) -> Tuple[bool, List[NNCFNode]]:
            is_finished = True
            successors = nncf_graph.get_next_nodes(node)
            if len(successors) == 1:
                successor = next(iter(successors))
                # It is necessary to determine the number of input nodes from the model
                # in order to correctly count the duplicated edges
                original_name, _ = get_original_name_and_instance_idx(
                    successor.node_name)
                layer = model.get_layer(name=original_name)
                num_previous_nodes = len(layer.input) if isinstance(
                    layer.input, list) else 1
                if successor.metatype in ELEMENTWISE_LAYER_METATYPES and num_previous_nodes == 1:
                    preprocessing_nodes.append(successor)
                    is_finished = False
            return is_finished, preprocessing_nodes

        for nncf_node in nncf_graph.get_input_nodes():
            preprocessing_nodes_for_this_input = nncf_graph.traverse_graph(
                nncf_node, traverse_fn)
            retval += preprocessing_nodes_for_this_input

        return retval

    def _get_quantized_nodes_for_output(
            self,
            nncf_graph: NNCFGraph,
            insertion_points: List[str],
            node_key: str,
            quantized_nodes_for_output: List[NNCFNode] = None
    ) -> List[NNCFNode]:
        nncf_node = nncf_graph.get_node_by_key(node_key)
        if quantized_nodes_for_output is None:
            if node_key in insertion_points:
                return [nncf_node]
            quantized_nodes_for_output = []

        for predecessor in nncf_graph.get_previous_nodes(nncf_node):
            pred_node_key = nncf_graph.get_node_key_by_id(predecessor.node_id)
            if len(nncf_graph.get_next_nodes(predecessor)) > 1:
                logger.warning(
                    'Removing FakeQuantize after layer {} '
                    'with multiple outputs is not fully supported'.format(
                        predecessor.node_name))
            if predecessor.metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION:
                self._get_quantized_nodes_for_output(
                    nncf_graph, insertion_points, pred_node_key,
                    quantized_nodes_for_output)
            elif nncf_graph.get_node_key_by_id(
                    predecessor.node_id) in insertion_points:
                quantized_nodes_for_output.append(predecessor)
        return quantized_nodes_for_output

    def _get_fake_quantize_name(self,
                                node_name: NNCFNodeName,
                                input_port_id: int = None) -> str:
        original_node_name, instance_idx = get_original_name_and_instance_idx(
            node_name)
        fq_name = '{}/fake_quantize'.format(original_node_name)
        if instance_idx != 0:
            fq_name += f"_{instance_idx}"
        if input_port_id is not None:
            fq_name += f"_I{input_port_id}"
        return fq_name

    def _get_quantizer_operation_name(self, layer_name, weight_attr_name):
        return f'{layer_name}_{weight_attr_name}_quantizer'
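
A hedged sketch of the surrounding flow; only the QuantizationBuilder calls below come from the
class above, and the step that applies the transformation layout to the model is left as a
comment rather than a concrete API call.

# Hypothetical usage sketch; `transformed_model` stands for the model returned by whatever
# transformer applies the TFTransformationLayout (not shown in the example above).
builder = QuantizationBuilder(nncf_config, should_init=True)
layout = builder.get_transformation_layout(model)  # builds (or reuses) the quantizer setup
# ... apply `layout` to `model` with the framework's model transformer -> transformed_model ...
builder.initialize(transformed_model)              # range initialization + BN adaptation
controller = builder._build_controller(transformed_model)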
Example #7
def _run_batchnorm_adaptation(self):
    if self._bn_adaptation is None:
        self._bn_adaptation = BatchnormAdaptationAlgorithm(
            **extract_bn_adaptation_init_params(self.config,
                                                'filter_pruning'))
    self._bn_adaptation.run(self.model)
Example #8
class FilterPruningController(BasePruningAlgoController):
    def __init__(self, target_model: NNCFNetwork, prunable_types: List[str],
                 pruned_module_groups: Clusterization[PrunedModuleInfo],
                 pruned_norms_operators: List[Tuple[NNCFNode,
                                                    FilterPruningMask,
                                                    torch.nn.Module]],
                 config: NNCFConfig):
        #pylint:disable=too-many-statements
        super().__init__(target_model, prunable_types, pruned_module_groups,
                         config)
        params = self.pruning_config.get('params', {})
        self._pruned_norms_operators = pruned_norms_operators
        self.frozen = False
        self._pruning_level = 0
        self.pruning_init = self.pruning_config.get('pruning_init', 0)
        self.pruning_quota = 0.9
        self.normalize_weights = True

        self._init_module_channels_and_shapes()
        self.pruning_quotas = {}
        self.nodes_flops = {}  # type: Dict[NNCFNodeName, int]
        self.nodes_params_num = {}  # type: Dict[NNCFNodeName, int]
        self.next_nodes = {}  # type: Dict[int, List[NNCFNodeName]]
        self._init_pruned_modules_params()
        self.flops_count_init()
        self.full_flops = sum(self.nodes_flops.values())
        self.current_flops = self.full_flops
        self.full_params_num = sum(self.nodes_params_num.values())
        self.current_params_num = self.full_params_num
        self.full_filters_num = count_filters_num(
            self._model.get_original_graph(), GENERAL_CONV_LAYER_METATYPES)
        self.current_filters_num = self.full_filters_num
        self._pruned_layers_num = len(
            self.pruned_module_groups_info.get_all_nodes())
        self._prunable_layers_num = len(
            self._model.get_graph().get_nodes_by_types(self._prunable_types))
        self._max_prunable_flops, self._max_prunable_params =\
            self._calculate_flops_and_weights_in_uniformly_pruned_model(1.)

        self.weights_normalizer = tensor_l2_normalizer  # used for all weights in the common case
        self.filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(
            params.get('filter_importance', 'L2'))
        self.ranking_type = params.get('interlayer_ranking_type',
                                       'unweighted_ranking')
        self.all_weights = params.get("all_weights", False)
        scheduler_cls = PRUNING_SCHEDULERS.get(
            params.get('schedule', 'exponential'))
        self._scheduler = scheduler_cls(self, params)

        if self.ranking_type == 'learned_ranking':
            # With the learned_ranking ranking type, weights should not be normalized
            self.normalize_weights = False
            if params.get('load_ranking_coeffs_path'):
                coeffs_path = params.get('load_ranking_coeffs_path')
                nncf_logger.info(
                    'Loading ranking coefficients from file {}'.format(
                        coeffs_path))
                try:
                    with open(coeffs_path, 'r',
                              encoding='utf8') as coeffs_file:
                        loaded_coeffs = json.load(coeffs_file)
                except (ValueError, FileNotFoundError) as err:
                    raise Exception(
                        'Can\'t load the JSON file with ranking coefficients. Please check the '
                        'format of the JSON file and the path to the file.') from err
                ranking_coeffs = {
                    key: tuple(loaded_coeffs[key])
                    for key in loaded_coeffs
                }
                nncf_logger.info(
                    'Loaded ranking coefficients = {}'.format(ranking_coeffs))
                self.ranking_coeffs = ranking_coeffs
            else:
                # Ranking can't be trained without registered init struct LeGRInitArgs
                if not config.has_extra_struct(LeGRInitArgs):
                    raise Exception(
                        'Please register LeGRInitArgs via the register_default_init_args function.'
                    )
                # Wrapping model for parallelization
                distributed_wrapping_init_args = config.get_extra_struct(
                    DistributedCallbacksArgs)
                target_model = distributed_wrapping_init_args.wrap_model(
                    target_model)
                legr_init_args = config.get_extra_struct(LeGRInitArgs)
                legr_params = params.get("legr_params", {})
                if 'max_pruning' not in legr_params:
                    legr_params['max_pruning'] = self._scheduler.target_level
                self.legr = LeGR(self, target_model, legr_init_args,
                                 **legr_params)
                self.ranking_coeffs = self.legr.train_global_ranking()
                nncf_logger.info('Trained ranking coefficients = {}'.format(
                    self.ranking_coeffs))
                # Unwrapping parallelized model
                target_model = distributed_wrapping_init_args.unwrap_model(
                    target_model)
        else:
            self.ranking_coeffs = {
                node.node_name: (1, 0)
                for node in self.pruned_module_groups_info.get_all_nodes()
            }

        # Saving ranking coefficients to the specified file
        if params.get('save_ranking_coeffs_path'):
            nncf_logger.info(
                'Saving ranking coefficients to the file {}'.format(
                    params.get('save_ranking_coeffs_path')))
            with open(params.get('save_ranking_coeffs_path'),
                      'w',
                      encoding='utf8') as f:
                json.dump(self.ranking_coeffs, f)

        self.set_pruning_level(self.pruning_init)
        self._bn_adaptation = None

    @property
    def loss(self) -> CompressionLoss:
        return self._loss

    @property
    def scheduler(self) -> PruningScheduler:
        return self._scheduler

    @staticmethod
    def get_mask(minfo: PrunedModuleInfo) -> torch.Tensor:
        return minfo.operand.binary_filter_pruning_mask

    @staticmethod
    def set_mask(minfo: PrunedModuleInfo, mask: torch.Tensor) -> None:
        minfo.operand.binary_filter_pruning_mask = mask

    def statistics(self,
                   quickly_collected_only: bool = False) -> NNCFStatistics:
        if not quickly_collected_only and is_debug():
            stats = PrunedModelTheoreticalBorderline(self._pruned_layers_num,
                                                     self._prunable_layers_num,
                                                     self._max_prunable_flops,
                                                     self._max_prunable_params,
                                                     self.full_flops,
                                                     self.full_params_num)

            nncf_logger.debug(stats.to_str())

        pruned_layers_summary = {}
        for minfo in self.pruned_module_groups_info.get_all_nodes():
            layer_name = str(minfo.module_scope)
            if layer_name not in pruned_layers_summary:
                pruned_layers_summary[layer_name] = \
                    PrunedLayerSummary(layer_name,
                                       list(minfo.module.weight.size()),
                                       list(self.mask_shape(minfo)),
                                       self.pruning_level_for_mask(minfo))

        self._update_benchmark_statistics()
        model_statistics = PrunedModelStatistics(
            self.full_flops, self.current_flops, self.full_params_num,
            self.current_params_num, self.full_filters_num,
            self.current_filters_num, list(pruned_layers_summary.values()))

        stats = FilterPruningStatistics(model_statistics,
                                        self.scheduler.current_pruning_level,
                                        self.scheduler.target_level,
                                        self.prune_flops)

        nncf_stats = NNCFStatistics()
        nncf_stats.register('filter_pruning', stats)
        return nncf_stats

    @property
    def pruning_level(self) -> float:
        """Global pruning level in the model"""
        return self._pruning_level

    def freeze(self, freeze: bool = True):
        self.frozen = freeze

    def _init_module_channels_and_shapes(self):
        self._modules_in_channels = {}  # type: Dict[NNCFNodeName, int]
        self._modules_out_channels = {}  # type: Dict[NNCFNodeName, int]
        self._modules_in_shapes = {}  # type: Dict[NNCFNodeName, List[int]]
        self._modules_out_shapes = {}  # type: Dict[NNCFNodeName, List[int]]

    def _init_pruned_modules_params(self):
        # 1. Init in/out channels for potentially prunable modules
        graph = self._model.get_original_graph()
        self._modules_in_channels, self._modules_out_channels = get_conv_in_out_channels(
            graph)

        # 2. Init next_nodes for every pruning cluster
        self.next_nodes = get_cluster_next_nodes(
            graph, self.pruned_module_groups_info, self._prunable_types)

        # 3. Init pruning quotas
        for cluster in self.pruned_module_groups_info.get_all_clusters():
            self.pruning_quotas[cluster.id] = np.floor(self._modules_out_channels[cluster.elements[0].node_name] \
                                                       * self.pruning_quota)

    def _calculate_output_shape(self, graph: NNCFGraph,
                                node: NNCFNode) -> Tuple[int, ...]:
        """
        Calculates output shape of convolution layer by input edge.

        :param graph: the model graph
        :param node: node from NNCF graph
        :return: output shape
        """
        in_edge = graph.get_input_edges(node)[0]
        shape = list(in_edge.tensor_shape)[2:]
        attrs = node.layer_attributes

        assert isinstance(attrs, ConvolutionLayerAttributes)

        for i, _ in enumerate(shape):
            if attrs.transpose:
                shape[i] = (shape[i] - 1) * attrs.stride[
                    i] - 2 * attrs.padding_values[i] + attrs.kernel_size[i]
            else:
                shape[i] = (shape[i] + 2 * attrs.padding_values[i] -
                            attrs.kernel_size[i]) // attrs.stride[i] + 1
        return tuple(shape)

    def flops_count_init(self) -> None:
        graph = self._model.get_original_graph()
        for node in graph.get_nodes_by_types(
            [v.op_func_name for v in NNCF_GENERAL_CONV_MODULES_DICT]):
            output_edges = graph.get_output_edges(node)
            if output_edges:
                out_edge = output_edges[0]
                out_shape = out_edge.tensor_shape[2:]
            else:
                # For a disconnected NNCFGraph, when the node has no output edge
                out_shape = self._calculate_output_shape(graph, node)
                nncf_logger.error("Node %s has no output edge in NNCFGraph",
                                  node.node_name)
            self._modules_out_shapes[node.node_name] = out_shape

        for node in graph.get_nodes_by_types(
            [v.op_func_name for v in NNCF_LINEAR_MODULES_DICT]):
            output_edges = graph.get_output_edges(node)
            if output_edges:
                out_edge = graph.get_output_edges(node)[0]
                out_shape = out_edge.tensor_shape
                self._modules_out_shapes[node.node_name] = out_shape[-1]
            else:
                # For a disconnected NNCFGraph, when the node has no output edge
                nncf_logger.error("Node %s has no output edge in NNCFGraph",
                                  node.node_name)
                self._modules_out_shapes[
                    node.node_name] = node.layer_attributes.out_features

            in_edge = graph.get_input_edges(node)[0]
            in_shape = in_edge.tensor_shape
            if len(in_shape) == 1:
                self._modules_in_shapes[node.node_name] = in_shape[0]
            else:
                self._modules_in_shapes[node.node_name] = in_shape[1:]

        self.nodes_flops, self.nodes_params_num = \
            count_flops_and_weights_per_node(graph, self._modules_in_shapes, self._modules_out_shapes,
                                             conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                                             linear_op_metatypes=LINEAR_LAYER_METATYPES)

    def _calculate_flops_and_weights_in_uniformly_pruned_model(
            self, pruning_level: float) -> Tuple[int, int]:
        """
        Prunes all prunable modules in the model with the given pruning_level and returns the
        number of flops and weights (parameters) of the pruned model.

        :param pruning_level: proportion of zero filters in all modules
        :return: number of flops and number of weights in the pruned model
        """
        tmp_in_channels, tmp_out_channels = \
            calculate_in_out_channels_in_uniformly_pruned_model(
                pruning_groups=self.pruned_module_groups_info.get_all_clusters(),
                pruning_level=pruning_level,
                full_input_channels=self._modules_in_channels,
                full_output_channels=self._modules_out_channels,
                pruning_groups_next_nodes=self.next_nodes)

        return count_flops_and_weights(
            self._model.get_original_graph(),
            self._modules_in_shapes,
            self._modules_out_shapes,
            input_channels=tmp_in_channels,
            output_channels=tmp_out_channels,
            conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            linear_op_metatypes=LINEAR_LAYER_METATYPES)

    def _find_uniform_pruning_level_for_target_flops(
            self, target_flops_pruning_level: float) -> float:
        """
        Searching for the minimal uniform layer-wise weight pruning level (proportion of zero filters in a layer)
         needed to achieve the target pruning level in flops.

        :param target_flops_pruning_level: target proportion of flops that should be pruned in the model
        :return: uniform pruning level for all layers
        """
        error = 0.01
        target_flops = self.full_flops * (1 - target_flops_pruning_level)
        left, right = 0.0, 1.0
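        # Bisect the uniform layer-wise pruning level: a higher level removes more flops, so
        # shrink the upper bound whenever the estimate already falls below the flops target.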
        while abs(right - left) > error:
            middle = (left + right) / 2
            flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(
                middle)
            if flops < target_flops:
                right = middle
            else:
                left = middle
        flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(
            right)
        if flops < target_flops:
            self.current_flops = flops
            self.current_params_num = params_num
            return right
        raise RuntimeError(
            "Can't prune the model to get the required "
            "pruning level in flops = {}".format(target_flops_pruning_level))

    def set_pruning_level(self,
                          pruning_level: Union[float, Dict[int, float]],
                          run_batchnorm_adaptation: bool = False) -> None:
        """
        Set the global or groupwise pruning level in the model.
        If pruning_level is a float, the corresponding global pruning level is set in the model,
        either in terms of the percentage of filters pruned or as the percentage of flops
        removed, the latter being the case when the "prune_flops" flag of the controller is
        set to True.
        If pruning_level is a dict, the keys should correspond to layer group ids and the
        values to the groupwise pruning levels to be set in the model.
        """
        groupwise_pruning_levels_set = isinstance(pruning_level, dict)
        passed_pruning_level = pruning_level

        if not self.frozen:
            nncf_logger.info(
                "Computing filter importance scores and binary masks...")
            with torch.no_grad():
                if self.all_weights:
                    if groupwise_pruning_levels_set:
                        raise RuntimeError(
                            'Cannot set group-wise pruning levels with '
                            'all_weights=True')
                    # Non-uniform (global) importance-score-based pruning according
                    # to the global pruning level
                    if self.prune_flops:
                        self._set_binary_masks_for_pruned_modules_globally_by_flops_target(
                            pruning_level)
                    else:
                        self._set_binary_masks_for_pruned_modules_globally(
                            pruning_level)
                else:
                    if groupwise_pruning_levels_set:
                        group_ids = [
                            group.id for group in
                            self.pruned_module_groups_info.get_all_clusters()
                        ]
                        if set(pruning_level.keys()) != set(group_ids):
                            raise RuntimeError(
                                'Groupwise pruning level dict keys do not correspond to '
                                'layer group ids')
                    else:
                        # Pruning uniformly with the same pruning level across layers
                        if self.prune_flops:
                            # Looking for layerwise pruning level needed for the required flops pruning level
                            pruning_level = self._find_uniform_pruning_level_for_target_flops(
                                pruning_level)
                    self._set_binary_masks_for_pruned_modules_groupwise(
                        pruning_level)

        self._propagate_masks()
        if not groupwise_pruning_levels_set:
            self._pruning_level = passed_pruning_level
        else:
            self._pruning_level = self._calculate_global_weight_pruning_level()

        if run_batchnorm_adaptation:
            self._run_batchnorm_adaptation()

    def _calculate_global_weight_pruning_level(self) -> float:
        full_param_count = 0
        pruned_param_count = 0
        for minfo in self.pruned_module_groups_info.get_all_nodes():
            layer_param_count = sum(p.numel()
                                    for p in minfo.module.parameters()
                                    if p.requires_grad)
            layer_weight_pruning_level = self.pruning_level_for_mask(minfo)
            full_param_count += layer_param_count
            pruned_param_count += layer_param_count * layer_weight_pruning_level
        return pruned_param_count / full_param_count

    @property
    def current_groupwise_pruning_level(self) -> Dict[int, float]:
        """
        Return the dict of layer group id's and corresponding current groupwise
        pruning levels in the model
        """
        groupwise_pruning_level_dict = {}
        for group in self.pruned_module_groups_info.get_all_clusters():
            groupwise_pruning_level_dict[
                group.id] = self.pruning_level_for_mask(group.elements[0])
        return groupwise_pruning_level_dict

    def _set_binary_masks_for_pruned_modules_groupwise(
            self, pruning_level: Union[float, Dict[int, float]]) -> None:
        """
        Set the binary mask values according to groupwise pruning level.
        If pruning_level is a float, set the pruning level uniformly across groups.
        If pruning_level is a dict, set specific pruning levels corresponding to each group.
        """
        nncf_logger.debug("Updating binary masks for pruned modules.")
        groupwise_pruning_levels_set = isinstance(pruning_level, dict)

        for group in self.pruned_module_groups_info.get_all_clusters():
            group_pruning_level = pruning_level[group.id] if groupwise_pruning_levels_set \
                else pruning_level

            filters_num = torch.tensor(
                [get_filters_num(minfo.module) for minfo in group.elements])
            assert torch.all(filters_num == filters_num[0])
            device = group.elements[0].module.weight.device

            cumulative_filters_importance = torch.zeros(
                filters_num[0]).to(device)
            # 1. Calculate cumulative importance for all filters in group
            for minfo in group.elements:
                filters_importance = self.filter_importance(
                    minfo.module.weight,
                    minfo.module.target_weight_dim_for_compression)
                cumulative_filters_importance += filters_importance

            # 2. Calculate threshold
            num_of_sparse_elems = get_rounded_pruned_element_number(
                cumulative_filters_importance.size(0), group_pruning_level)
            threshold = sorted(cumulative_filters_importance)[min(
                num_of_sparse_elems, filters_num[0] - 1)]
            mask = calculate_binary_mask(cumulative_filters_importance,
                                         threshold)

            # 3. Set binary masks for filter
            for minfo in group.elements:
                pruning_module = minfo.operand
                pruning_module.binary_filter_pruning_mask = mask

        # Calculate actual flops and weights number with new masks
        self._update_benchmark_statistics()

    def _set_binary_masks_for_pruned_modules_globally(
            self, pruning_level: float) -> None:
        """
        Set the binary mask values for layer groups according to the global pruning level.
        Filter importance scores in each group are merged into a single global list and a
        threshold value separating the pruning_level proportion of the least important filters
        in the model is calculated. Filters are pruned globally according to the threshold value.
        """
        nncf_logger.debug(
            "Setting new binary masks for all pruned modules together.")
        filter_importances = []
        # 1. Calculate importances for all groups of filters
        for group in self.pruned_module_groups_info.get_all_clusters():
            filters_num = torch.tensor(
                [get_filters_num(minfo.module) for minfo in group.elements])
            assert torch.all(filters_num == filters_num[0])
            device = group.elements[0].module.weight.device

            cumulative_filters_importance = torch.zeros(
                filters_num[0]).to(device)
            # Calculate cumulative importance for all filters in this group
            for minfo in group.elements:
                normalized_weight = self.weights_normalizer(
                    minfo.module.weight)
                filters_importance = self.filter_importance(
                    normalized_weight,
                    minfo.module.target_weight_dim_for_compression)
                cumulative_filters_importance += filters_importance

            filter_importances.append(cumulative_filters_importance)

        # 2. Calculate one threshold for all weights
        importances = torch.cat(filter_importances)
        threshold = sorted(importances)[int(pruning_level *
                                            importances.size(0))]

        # 3. Set binary masks for filters in groups
        for i, group in enumerate(
                self.pruned_module_groups_info.get_all_clusters()):
            mask = calculate_binary_mask(filter_importances[i], threshold)
            for minfo in group.elements:
                pruning_module = minfo.operand
                pruning_module.binary_filter_pruning_mask = mask

        # Calculate actual flops and weights number with new masks
        self._update_benchmark_statistics()

    def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
            self, target_flops_pruning_level: float) -> None:
        """
        Sorts all prunable filters in the network by importance and prunes the least important
        filters until the target FLOPs pruning level is reached. Filters are pruned one by one,
        and the resulting FLOPs count is checked after each step.

        :param target_flops_pruning_level: target proportion of flops removed from the model
        :return:
        """
        target_flops = self.full_flops * (1 - target_flops_pruning_level)

        # 1. Initialize masks
        for minfo in self.pruned_module_groups_info.get_all_nodes():
            new_mask = torch.ones(get_filters_num(minfo.module)).to(
                minfo.module.weight.device)
            self.set_mask(minfo, new_mask)

        # 2. Calculate filter importances for all prunable groups
        filter_importances = []
        cluster_indexes = []
        filter_indexes = []

        for cluster in self.pruned_module_groups_info.get_all_clusters():
            filters_num = torch.tensor(
                [get_filters_num(minfo.module) for minfo in cluster.elements])
            assert torch.all(filters_num == filters_num[0])
            device = cluster.elements[0].module.weight.device

            cumulative_filters_importance = torch.zeros(
                filters_num[0]).to(device)
            # Calculate cumulative importance for all filters in this group
            for minfo in cluster.elements:
                weight = minfo.module.weight
                if self.normalize_weights:
                    weight = self.weights_normalizer(weight)
                filters_importance = self.filter_importance(
                    weight, minfo.module.target_weight_dim_for_compression)
                scaled_importance = self.ranking_coeffs[minfo.node_name][0] * filters_importance + \
                                    self.ranking_coeffs[minfo.node_name][1]
                cumulative_filters_importance += scaled_importance

            filter_importances.append(cumulative_filters_importance)
            cluster_indexes.append(
                cluster.id * torch.ones_like(cumulative_filters_importance))
            filter_indexes.append(
                torch.arange(len(cumulative_filters_importance)))

        importances = torch.cat(filter_importances)
        cluster_indexes = torch.cat(cluster_indexes)
        filter_indexes = torch.cat(filter_indexes)

        # 3. Sort all filter groups by importances and prune the least important filters
        # until target flops pruning level is achieved
        sorted_importances = sorted(zip(importances, cluster_indexes,
                                        filter_indexes),
                                    key=lambda x: x[0])
        cur_num = 0
        tmp_in_channels = self._modules_in_channels.copy()
        tmp_out_channels = self._modules_out_channels.copy()
        tmp_pruning_quotas = self.pruning_quotas.copy()

        while cur_num < len(sorted_importances):
            cluster_idx = int(sorted_importances[cur_num][1])
            filter_idx = int(sorted_importances[cur_num][2])

            if tmp_pruning_quotas[cluster_idx] > 0:
                tmp_pruning_quotas[cluster_idx] -= 1
            else:
                cur_num += 1
                continue

            cluster = self.pruned_module_groups_info.get_cluster_by_id(
                cluster_idx)
            for node in cluster.elements:
                tmp_out_channels[node.node_name] -= 1
                if node.is_depthwise:
                    tmp_in_channels[node.node_name] -= 1

                node.operand.binary_filter_pruning_mask[filter_idx] = 0

            # Prune input channels in all subsequent nodes
            next_nodes = self.next_nodes[cluster.id]
            for node_id in next_nodes:
                tmp_in_channels[node_id] -= 1

            flops, params_num = count_flops_and_weights(
                self._model.get_original_graph(),
                self._modules_in_shapes,
                self._modules_out_shapes,
                input_channels=tmp_in_channels,
                output_channels=tmp_out_channels,
                conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                linear_op_metatypes=LINEAR_LAYER_METATYPES)
            if flops < target_flops:
                self.current_flops = flops
                self.current_params_num = params_num
                return
            cur_num += 1
        raise RuntimeError("Cannot prune the model to the requested FLOPs pruning level")

    def _propagate_masks(self):
        nncf_logger.debug("Propagating pruning masks")
        # 1. Propagate masks for all modules
        graph = self.model.get_original_graph()

        init_output_masks_in_graph(
            graph, self.pruned_module_groups_info.get_all_nodes())
        MaskPropagationAlgorithm(
            graph, PT_PRUNING_OPERATOR_METATYPES,
            PTNNCFPruningTensorProcessor).mask_propagation()

        # 2. Set the masks for Batch/Group Norms
        pruned_node_modules = []
        for node, pruning_block, node_module in self._pruned_norms_operators:
            if node_module not in pruned_node_modules:
                # Setting masks for BN nodes
                pruning_block.binary_filter_pruning_mask = node.data[
                    'output_mask'].tensor
                pruned_node_modules.append(node_module)

    def prepare_for_export(self):
        """
        Applies pruning masks to layer weights before exporting the model to ONNX.
        """
        self._propagate_masks()

        pruned_layers_stats = self.get_stats_for_pruned_modules()
        nncf_logger.debug('Pruned layers statistics: \n%s',
                          pruned_layers_stats.draw())

    def compression_stage(self) -> CompressionStage:
        target_pruning_level = self.scheduler.target_level
        actual_pruning_level = self._pruning_level
        if actual_pruning_level == 0:
            return CompressionStage.UNCOMPRESSED
        if actual_pruning_level >= target_pruning_level:
            return CompressionStage.FULLY_COMPRESSED
        return CompressionStage.PARTIALLY_COMPRESSED

    @property
    def compression_rate(self):
        if self.prune_flops:
            return 1 - self.current_flops / self.full_flops
        return self.pruning_level

    @compression_rate.setter
    def compression_rate(self, pruning_rate):
        is_pruning_controller_frozen = self.frozen
        self.freeze(False)
        self.set_pruning_level(pruning_rate)
        self.freeze(is_pruning_controller_frozen)

    def disable_scheduler(self):
        self._scheduler = StubCompressionScheduler()
        self._scheduler.current_pruning_level = 0.0

    def _collect_pruning_masks(self) -> Dict[str, PTNNCFTensor]:
        retval = {}
        for group in self.pruned_module_groups_info.get_all_clusters():
            for node in group.elements:
                retval[node.node_name] = PTNNCFTensor(
                    node.operand.binary_filter_pruning_mask)
        return retval

    def _update_benchmark_statistics(self):
        tmp_in_channels, tmp_out_channels = calculate_in_out_channels_by_masks(
            pruning_groups=self.pruned_module_groups_info.get_all_clusters(),
            masks=self._collect_pruning_masks(),
            tensor_processor=PTNNCFCollectorTensorProcessor,
            full_input_channels=self._modules_in_channels,
            full_output_channels=self._modules_out_channels,
            pruning_groups_next_nodes=self.next_nodes)

        self.current_filters_num = count_filters_num(
            self._model.get_original_graph(),
            op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            output_channels=tmp_out_channels)

        self.current_flops, self.current_params_num = \
            count_flops_and_weights(self._model.get_original_graph(),
                                    self._modules_in_shapes,
                                    self._modules_out_shapes,
                                    input_channels=tmp_in_channels,
                                    output_channels=tmp_out_channels,
                                    conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                                    linear_op_metatypes=LINEAR_LAYER_METATYPES)

    def _run_batchnorm_adaptation(self):
        if self._bn_adaptation is None:
            self._bn_adaptation = BatchnormAdaptationAlgorithm(
                **extract_bn_adaptation_init_params(self.config,
                                                    'filter_pruning'))
        self._bn_adaptation.run(self.model)
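Both the groupwise and the global mask computations above reduce to the same thresholding step: sort the cumulative filter importances, take the value separating the pruning_level fraction of the least important filters, and binarize against it. The snippet below is a standalone sketch of that step only; build_binary_mask is an illustrative helper, not part of NNCF.

import torch


def build_binary_mask(importances: torch.Tensor, pruning_level: float) -> torch.Tensor:
    """Keep the (1 - pruning_level) most important filters; simplified sketch."""
    num_to_prune = int(pruning_level * importances.numel())
    # Never prune every filter in a group, mirroring the
    # min(num_of_sparse_elems, filters_num - 1) guard used above.
    num_to_prune = min(num_to_prune, importances.numel() - 1)
    threshold = torch.sort(importances).values[num_to_prune]
    return (importances >= threshold).float()


# Example: prune half of four filters ranked by the L2 norm of their weights.
weights = torch.randn(4, 3, 3, 3)                   # [out_channels, in_channels, kH, kW]
importances = weights.flatten(1).norm(p=2, dim=1)   # per-filter L2 importance
mask = build_binary_mask(importances, pruning_level=0.5)
print(mask)  # e.g. tensor([1., 0., 1., 0.])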
Example no. 9
0
class FilterPruningController(BasePruningAlgoController):
    """
    Serves as a handle to the additional modules, parameters and hooks inserted
    into the original uncompressed model to enable filter pruning.
    """
    def __init__(self, target_model: tf.keras.Model, graph: NNCFGraph,
                 op_names: List[str], prunable_types: List[str],
                 pruned_layer_groups: Clusterization[PrunedLayerInfo],
                 config: NNCFConfig):
        super().__init__(target_model, op_names, prunable_types,
                         pruned_layer_groups, config)
        self._original_graph = graph
        params = self.pruning_config.get('params', {})
        self.frozen = False
        self.pruning_quota = 0.9

        self._nodes_flops = {}  # type: Dict[NNCFNodeName, int]
        self._nodes_params_num = {}  # type: Dict[NNCFNodeName, int]
        self._layers_in_channels = {}
        self._layers_out_channels = {}
        self._layers_in_shapes = {}
        self._layers_out_shapes = {}
        self._pruning_quotas = {}
        self._next_nodes = {}
        self._init_pruned_layers_params()
        self._flops_count_init()
        self.full_flops = sum(self._nodes_flops.values())
        self.current_flops = self.full_flops
        self.full_params_num = sum(self._nodes_params_num.values())
        self.current_params_num = self.full_params_num
        self.full_filters_num = count_filters_num(
            self._original_graph, GENERAL_CONV_LAYER_METATYPES)
        self.current_filters_num = self.full_filters_num
        self._pruned_layers_num = len(
            self._pruned_layer_groups_info.get_all_nodes())
        self._prunable_layers_num = len(
            self._original_graph.get_nodes_by_types(self._prunable_types))
        self._max_prunable_flops, self._max_prunable_params = \
            self._calculate_flops_and_weights_in_uniformly_pruned_model(1.)

        self._weights_normalizer = tensor_l2_normalizer  # for all weights in common case
        self._filter_importance = FILTER_IMPORTANCE_FUNCTIONS.get(
            params.get('filter_importance', 'L2'))
        self.all_weights = params.get('all_weights', False)
        scheduler_cls = PRUNING_SCHEDULERS.get(
            params.get('schedule', 'exponential'))
        self._scheduler = scheduler_cls(self, params)
        self._bn_adaptation = None
        self.set_pruning_level(self.pruning_init)
        self._loss = TFZeroCompressionLoss()

    @property
    def scheduler(self) -> PruningScheduler:
        return self._scheduler

    @property
    def loss(self) -> CompressionLoss:
        return self._loss

    @property
    def compression_rate(self) -> float:
        if self.prune_flops:
            return 1 - self.current_flops / self.full_flops
        return self.pruning_rate

    @compression_rate.setter
    def compression_rate(self, compression_rate: float) -> None:
        is_pruning_controller_frozen = self.frozen
        self.freeze(False)
        self.set_pruning_level(compression_rate)
        self.freeze(is_pruning_controller_frozen)

    def disable_scheduler(self):
        self._scheduler = StubCompressionScheduler()
        self._scheduler.current_pruning_level = 0.0

    def statistics(self,
                   quickly_collected_only: bool = False) -> NNCFStatistics:
        if not quickly_collected_only and is_debug():
            stats = PrunedModelTheoreticalBorderline(self._pruned_layers_num,
                                                     self._prunable_layers_num,
                                                     self._max_prunable_flops,
                                                     self._max_prunable_params,
                                                     self.full_flops,
                                                     self.full_params_num)

            nncf_logger.debug(stats.to_str())

        pruned_layers_summary = self._calculate_pruned_layers_summary()
        self._update_benchmark_statistics()
        model_statistics = PrunedModelStatistics(
            self.full_flops, self.current_flops, self.full_params_num,
            self.current_params_num, self.full_filters_num,
            self.current_filters_num, pruned_layers_summary)

        stats = FilterPruningStatistics(model_statistics,
                                        self.scheduler.current_pruning_level,
                                        self.scheduler.target_level,
                                        self.prune_flops)

        nncf_stats = NNCFStatistics()
        nncf_stats.register('filter_pruning', stats)
        return nncf_stats

    def freeze(self, freeze: bool = True):
        self.frozen = freeze

    def set_pruning_level(self,
                          pruning_level: float,
                          run_batchnorm_adaptation: bool = False):
        """
        Sets up pruning masks in accordance with the provided pruning level.

        :param pruning_level: pruning level
        :return:
        """
        # The pruning level from the scheduler can be the percentage of params that should be pruned
        self.pruning_rate = pruning_level
        if not self.frozen:
            nncf_logger.info(
                'Computing filter importance scores and binary masks...')
            if self.all_weights:
                if self.prune_flops:
                    self._set_binary_masks_for_pruned_modules_globally_by_flops_target(
                        pruning_level)
                else:
                    self._set_binary_masks_for_pruned_layers_globally(
                        pruning_level)
            else:
                if self.prune_flops:
                    # Looking for a layerwise pruning rate needed for the required flops pruning rate
                    pruning_level = self._find_uniform_pruning_level_for_target_flops(
                        pruning_level)
                self._set_binary_masks_for_pruned_layers_groupwise(
                    pruning_level)

        if run_batchnorm_adaptation:
            self._run_batchnorm_adaptation()

    def _init_pruned_layers_params(self):
        # 1. Initialize in/out channels for potentially prunable layers
        self._layers_in_channels, self._layers_out_channels = get_conv_in_out_channels(
            self._original_graph)

        # 2. Initialize next_nodes for each pruning cluster
        self._next_nodes = get_cluster_next_nodes(
            self._original_graph, self._pruned_layer_groups_info,
            self._prunable_types)

        # 3. Initialize pruning quotas
        for cluster in self._pruned_layer_groups_info.get_all_clusters():
            self._pruning_quotas[cluster.id] = floor(
                self._layers_out_channels[cluster.elements[0].node_name] *
                self.pruning_quota)

    def _flops_count_init(self):
        """
        Collects input/output shapes of convolutional and dense layers and
        calculates the corresponding layerwise FLOPs.
        """
        for node in self._original_graph.get_nodes_by_metatypes(
                GENERAL_CONV_LAYER_METATYPES):
            node_name, node_index = get_original_name_and_instance_idx(
                node.node_name)
            layer = self._model.get_layer(node_name)
            layer_ = unwrap_layer(layer)

            channel_axis = get_input_channel_axis(layer)
            dims_slice = slice(channel_axis - layer_.rank, channel_axis) \
                if layer.data_format == 'channels_last' else slice(channel_axis + 1, None)
            in_shape = layer.get_input_shape_at(node_index)[dims_slice]
            out_shape = layer.get_output_shape_at(node_index)[dims_slice]

            if not is_valid_shape(in_shape) or not is_valid_shape(out_shape):
                raise RuntimeError(
                    f'Input/output shape is not defined for layer `{layer.name}` '
                )

            self._layers_in_shapes[node.node_name] = in_shape
            self._layers_out_shapes[node.node_name] = out_shape

        for node in self._original_graph.get_nodes_by_metatypes(
                LINEAR_LAYER_METATYPES):
            node_name, node_index = get_original_name_and_instance_idx(
                node.node_name)
            layer = self._model.get_layer(node_name)

            in_shape = layer.get_input_shape_at(node_index)[1:]
            out_shape = layer.get_output_shape_at(node_index)[1:]

            if not is_valid_shape(in_shape) or not is_valid_shape(out_shape):
                raise RuntimeError(
                    f'Input/output shape is not defined for layer `{layer.name}` '
                )

            self._layers_in_shapes[node.node_name] = in_shape
            self._layers_out_shapes[node.node_name] = out_shape

        self._nodes_flops, self._nodes_params_num = \
            count_flops_and_weights_per_node(self._original_graph,
                                             self._layers_in_shapes,
                                             self._layers_out_shapes,
                                             conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                                             linear_op_metatypes=LINEAR_LAYER_METATYPES)

    def _set_binary_masks_for_pruned_layers_groupwise(self,
                                                      pruning_level: float):
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Removing masks at the elements of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            # a. Calculate the cumulative importance for all filters in the group
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filters_num = len(cumulative_filters_importance)

            # b. Calculate threshold
            num_of_sparse_elems = get_rounded_pruned_element_number(
                cumulative_filters_importance.shape[0], pruning_level)
            threshold = sorted(cumulative_filters_importance)[min(
                num_of_sparse_elems, filters_num - 1)]

            # c. Initialize masks
            filter_mask = calculate_binary_mask(cumulative_filters_importance,
                                                threshold)
            for node in group.elements:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

        # 2. Propagating masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
            TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'].tensor)

        # Calculate actual flops and weights number with new masks
        self._update_benchmark_statistics()

    def _set_binary_masks_for_pruned_layers_globally(self,
                                                     pruning_level: float):
        """
        Sets the binary mask values for layer groups according to the global pruning level.
        Filter importance scores in each group are merged into a single global list and a
        threshold value separating the pruning_level proportion of the least important filters
        in the model is calculated. Filters are pruned globally according to the threshold value.
        """
        nncf_logger.debug(
            'Setting new binary masks for all pruned modules together.')
        filter_importances = {}
        wrapped_layers = collect_wrapped_layers(self._model)

        # 0. Remove masks at the elements of the NNCFGraph
        for node in self._original_graph.topological_sort():
            node.data.pop('output_mask', None)

        # 1. Calculate masks
        # a. Calculate importances for all groups of filters
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filter_importances[group.id] = cumulative_filters_importance

        # b. Calculate one threshold for all weights
        importances = tf.concat(list(filter_importances.values()), 0)
        threshold = sorted(importances)[int(pruning_level *
                                            importances.shape[0])]

        # c. Initialize masks
        for group in self._pruned_layer_groups_info.get_all_clusters():
            filter_mask = calculate_binary_mask(filter_importances[group.id],
                                                threshold)
            for node in group.elements:
                nncf_node = self._original_graph.get_node_by_id(
                    node.nncf_node_id)
                nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask)

        # 2. Propagate masks across the graph
        mask_propagator = MaskPropagationAlgorithm(
            self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
            TFNNCFPruningTensorProcessor)
        mask_propagator.mask_propagation()

        # 3. Apply masks to the model
        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            if nncf_node.data['output_mask'] is not None:
                self._set_operation_masks([layer],
                                          nncf_node.data['output_mask'].tensor)

        # Calculate actual flops with new masks
        self._update_benchmark_statistics()

    def _set_binary_masks_for_pruned_modules_globally_by_flops_target(
            self, target_flops_pruning_level: float):
        """
        Prunes the least important filters one by one until the target FLOPs pruning level
        is achieved. Filters are sorted by their filter importance score.
        """
        nncf_logger.debug('Setting new binary masks for pruned layers.')
        target_flops = self.full_flops * (1 - target_flops_pruning_level)
        wrapped_layers = collect_wrapped_layers(self._model)
        masks = {}

        nncf_sorted_nodes = self._original_graph.topological_sort()
        for layer in wrapped_layers:
            nncf_node = [
                n for n in nncf_sorted_nodes if layer.name == n.layer_name
            ][0]
            nncf_node.data['output_mask'] = TFNNCFTensor(
                tf.ones(get_filters_num(layer)))

        # 1. Calculate importances for all groups of filters. Initialize masks.
        filter_importances = []
        group_indexes = []
        filter_indexes = []
        for group in self._pruned_layer_groups_info.get_all_clusters():
            cumulative_filters_importance = self._calculate_filters_importance_in_group(
                group)
            filter_importances.extend(cumulative_filters_importance)
            filters_num = len(cumulative_filters_importance)
            group_indexes.extend([group.id] * filters_num)
            filter_indexes.extend(range(filters_num))
            masks[group.id] = tf.ones(filters_num)

        # 2. Sort all filters by importance and prune the least important ones one by one
        # until the target FLOPs level is reached
        tmp_in_channels = self._layers_in_channels.copy()
        tmp_out_channels = self._layers_out_channels.copy()
        sorted_importances = sorted(zip(filter_importances, group_indexes,
                                        filter_indexes),
                                    key=lambda x: x[0])
        for _, group_id, filter_index in sorted_importances:
            if self._pruning_quotas[group_id] == 0:
                continue
            masks[group_id] = tf.tensor_scatter_nd_update(
                masks[group_id], [[filter_index]], [0])
            self._pruning_quotas[group_id] -= 1

            # Update input/output channel counts of the pruned elements
            group = self._pruned_layer_groups_info.get_cluster_by_id(group_id)
            for node in group.elements:
                tmp_out_channels[node.node_name] -= 1
                if node.is_depthwise:
                    tmp_in_channels[node.node_name] -= 1

            for node_name in self._next_nodes[group_id]:
                tmp_in_channels[node_name] -= 1

            flops, params_num = count_flops_and_weights(
                self._original_graph,
                self._layers_in_shapes,
                self._layers_out_shapes,
                input_channels=tmp_in_channels,
                output_channels=tmp_out_channels,
                conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                linear_op_metatypes=LINEAR_LAYER_METATYPES)
            if flops <= target_flops:
                # 3. Add masks to the graph and propagate them
                for group in self._pruned_layer_groups_info.get_all_clusters():
                    for node in group.elements:
                        nncf_node = self._original_graph.get_node_by_id(
                            node.nncf_node_id)
                        nncf_node.data['output_mask'] = TFNNCFTensor(
                            masks[group.id])

                mask_propagator = MaskPropagationAlgorithm(
                    self._original_graph, TF_PRUNING_OPERATOR_METATYPES,
                    TFNNCFPruningTensorProcessor)
                mask_propagator.mask_propagation()

                # 4. Set binary masks to the model
                self.current_flops = flops
                self.current_params_num = params_num
                nncf_sorted_nodes = self._original_graph.topological_sort()
                for layer in wrapped_layers:
                    nncf_node = [
                        n for n in nncf_sorted_nodes
                        if layer.name == n.layer_name
                    ][0]
                    if nncf_node.data['output_mask'] is not None:
                        self._set_operation_masks(
                            [layer], nncf_node.data['output_mask'].tensor)
                return
        raise RuntimeError(
            f'Unable to prune the model to the required FLOPs pruning level:'
            f' {target_flops_pruning_level}')

    def _set_operation_masks(self, layers: List[NNCFWrapper], filter_mask):
        for layer in layers:
            for weight_attr, ops in layer.weights_attr_ops.items():
                weight_shape = layer.layer_weights[weight_attr].shape
                for op_name, op in ops.items():
                    if isinstance(op, BinaryMask):
                        filter_axis = get_filter_axis(layer, weight_attr)
                        broadcasted_mask = broadcast_filter_mask(
                            filter_mask, weight_shape, filter_axis)
                        layer.ops_weights[op_name]['mask'].assign(
                            broadcasted_mask)

    def _find_uniform_pruning_level_for_target_flops(
            self, target_flops_pruning_level):
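        """
        Bisects the uniform layerwise pruning level until the estimated FLOPs of the
        uniformly pruned model drop below full_flops * (1 - target_flops_pruning_level).
        """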
        error = 0.01
        target_flops = self.full_flops * (1 - target_flops_pruning_level)
        left, right = 0.0, 1.0
        while abs(right - left) > error:
            middle = (left + right) / 2
            flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(
                middle)
            if flops < target_flops:
                right = middle
            else:
                left = middle
        flops, params_num = self._calculate_flops_and_weights_in_uniformly_pruned_model(
            right)
        if flops < target_flops:
            self.current_flops = flops
            self.current_params_num = params_num
            return right
        raise RuntimeError(
            f'Unable to prune the model to the required FLOPs '
            f'pruning level: {target_flops_pruning_level}')

    def _calculate_flops_and_weights_in_uniformly_pruned_model(
            self, pruning_level):
        tmp_in_channels, tmp_out_channels = \
            calculate_in_out_channels_in_uniformly_pruned_model(
                pruning_groups=self._pruned_layer_groups_info.get_all_clusters(),
                pruning_level=pruning_level,
                full_input_channels=self._layers_in_channels,
                full_output_channels=self._layers_out_channels,
                pruning_groups_next_nodes=self._next_nodes)

        return count_flops_and_weights(
            self._original_graph,
            self._layers_in_shapes,
            self._layers_out_shapes,
            input_channels=tmp_in_channels,
            output_channels=tmp_out_channels,
            conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            linear_op_metatypes=LINEAR_LAYER_METATYPES)

    def _calculate_filters_importance_in_group(
            self, group: Cluster[PrunedLayerInfo]):
        """
        Calculates the cumulative filter importance for all filters in the group.

        :param group: Cluster of pruned layer nodes
        :return: a tensor of cumulative filter importance scores
        """
        group_layers = [
            self._model.get_layer(node.layer_name) for node in group.elements
        ]
        group_filters_num = tf.constant(
            [get_filters_num(layer) for layer in group_layers])
        filters_num = group_filters_num[0]
        assert tf.reduce_all(group_filters_num == filters_num)

        cumulative_filters_importance = tf.zeros(filters_num)
        # Calculate cumulative importance for all filters in this group
        shared_nodes = set()  # type: Set[str]
        for minfo in group.elements:
            layer_name = minfo.layer_name
            if layer_name in shared_nodes:
                continue
            nncf_node = self._original_graph.get_node_by_id(minfo.nncf_node_id)
            if nncf_node.is_shared():
                shared_nodes.add(layer_name)
            filters_importance = self._layer_filter_importance(
                self._model.get_layer(layer_name))
            cumulative_filters_importance += filters_importance

        return cumulative_filters_importance

    def _collect_pruning_masks(self) -> Dict[str, TFNNCFTensor]:
        retval = {}
        for group in self._pruned_layer_groups_info.get_all_clusters():
            for node in group.elements:
                retval[node.node_name] = self._original_graph.get_node_by_name(
                    node.node_name).data['output_mask']
        return retval

    def _update_benchmark_statistics(self):
        tmp_in_channels, tmp_out_channels = calculate_in_out_channels_by_masks(
            pruning_groups=self._pruned_layer_groups_info.get_all_clusters(),
            masks=self._collect_pruning_masks(),
            tensor_processor=TFNNCFCollectorTensorProcessor,
            full_input_channels=self._layers_in_channels,
            full_output_channels=self._layers_out_channels,
            pruning_groups_next_nodes=self._next_nodes)

        self.current_filters_num = count_filters_num(
            self._original_graph,
            op_metatypes=GENERAL_CONV_LAYER_METATYPES,
            output_channels=tmp_out_channels)

        self.current_flops, self.current_params_num = \
            count_flops_and_weights(self._original_graph,
                                    self._layers_in_shapes,
                                    self._layers_out_shapes,
                                    input_channels=tmp_in_channels,
                                    output_channels=tmp_out_channels,
                                    conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
                                    linear_op_metatypes=LINEAR_LAYER_METATYPES)

    def _layer_filter_importance(self, layer: NNCFWrapper):
        layer_metatype = get_keras_layer_metatype(layer)
        if len(layer_metatype.weight_definitions) != 1:
            raise RuntimeError(
                f'The layer {layer.layer.name} is not supported by the pruning '
                f'algorithm because it contains several weight attributes.')
        weight_attr = layer_metatype.weight_definitions[0].weight_attr_name
        weight = layer.layer_weights[weight_attr]
        if self.all_weights:
            weight = self._weights_normalizer(weight)
        target_weight_dim_for_compression = get_filter_axis(layer, weight_attr)
        filters_importance = self._filter_importance(
            weight, target_weight_dim_for_compression)
        return filters_importance

    def _run_batchnorm_adaptation(self):
        if self._bn_adaptation is None:
            self._bn_adaptation = BatchnormAdaptationAlgorithm(
                **extract_bn_adaptation_init_params(self.config,
                                                    'filter_pruning'))
        self._bn_adaptation.run(self.model)
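_find_uniform_pruning_level_for_target_flops above bisects the layerwise pruning level until the FLOPs estimate falls below the target. The following self-contained sketch shows the same idea with a toy flops_at cost model standing in for count_flops_and_weights; the names are illustrative and not part of the NNCF API.

def find_pruning_level_for_flops(full_flops, target_flops_pruning_level, flops_at, tol=0.01):
    """Bisect the uniform pruning level until flops_at(level) drops below the target."""
    target_flops = full_flops * (1 - target_flops_pruning_level)
    left, right = 0.0, 1.0
    while right - left > tol:
        middle = (left + right) / 2
        if flops_at(middle) < target_flops:
            right = middle   # enough pruning - try a smaller level
        else:
            left = middle    # not enough pruning - increase the level
    if flops_at(right) < target_flops:
        return right
    raise RuntimeError('Target FLOPs pruning level is unreachable')


# Toy cost model: FLOPs of a conv stack shrink roughly quadratically with channel pruning.
full = 1e9
level = find_pruning_level_for_flops(full, 0.5, lambda p: full * (1 - p) ** 2)
print(round(level, 2))  # ~0.3: pruning about 29% of the channels halves the FLOPs here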
Example no. 10
0
 def _run_batchnorm_adaptation(self):
     if self._bn_adaptation is None:
         self._bn_adaptation = BatchnormAdaptationAlgorithm(
             **extract_bn_adaptation_init_params(self.qctrl.config,
                                                 "quantization"))
     self._bn_adaptation.run(self.qctrl.model)
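Each _run_batchnorm_adaptation variant in these examples follows the same pattern: lazily construct a BatchnormAdaptationAlgorithm from the algorithm-specific section of the NNCF config and run it on the compressed model. The parameters are typically taken from an initializer.batchnorm_adaptation block of the config; the sketch below assumes the standard NNCF JSON layout, and the sample count and input shape are illustrative only.

nncf_config_dict = {
    "input_info": {"sample_size": [1, 3, 224, 224]},
    "compression": {
        "algorithm": "quantization",
        "initializer": {
            # Number of samples passed through the model so that the BatchNorm running
            # statistics are re-estimated on the quantized activations; 0 disables it.
            "batchnorm_adaptation": {"num_bn_adaptation_samples": 2000}
        }
    }
}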
Example no. 11
0
class QuantizationEnv:
    # pylint:disable=too-many-branches,too-many-statements
    def __init__(self, model: NNCFNetwork,
                 quantization_controller: ExperimentalQuantizationController,
                 hw_precision_constraints: HardwareQuantizationConstraints,
                 eval_loader: torch.utils.data.DataLoader,
                 eval_fn: Callable[[nn.Module, torch.utils.data.DataLoader],
                                   float], hw_config_type: HWConfigType,
                 params: QuantizationEnvParams):

        logger.info("[Q.Env] Instantiating NNCF Quantization Environment")
        self.qctrl = quantization_controller
        self.qmodel = model
        self.eval_loader = eval_loader
        self.eval_fn = eval_fn
        self._hw_precision_constraints = hw_precision_constraints
        self._bn_adaptation = None

        self.model_name = self.qmodel.nncf_module.__class__.__name__

        # Check and only proceed if target device is supported by Q.Env
        self.hw_cfg_type = hw_config_type
        assert self.hw_cfg_type in [None, HWConfigType.VPU]

        # Set target compression ratio
        self.compression_ratio = params.compression_ratio

        self.eval_loader = PartialDataLoader(
            self.eval_loader, iter_ratio=params.eval_subset_ratio)

        # Bool to disable hard resource constraint
        self.skip_constraint = params.skip_constraint

        # Bool to enable bw alignment of adj. Q group to lower precision
        self.performant_bw = params.performant_bw

        # Bool to enable fine-tuning in each episode. Placeholder for now
        self.finetune = False

        # Counter for number of evaluate_strategy calls
        self._n_eval = 0

        # Configure search space for precision according to target device
        if self.hw_cfg_type is None:
            self.model_bitwidth_space = params.bits
        elif self.hw_cfg_type is HWConfigType.VPU:
            self.model_bitwidth_space = self._hw_precision_constraints.get_all_unique_bitwidths(
            )
        self.model_bitwidth_space = sorted(list(self.model_bitwidth_space))

        # Create mapping of QuantizerId to the space of the corresponding quantizer's allowed qconfigs
        #pylint:disable=line-too-long
        self.qconfig_space_map = OrderedDict.fromkeys(
            self.qctrl.all_quantizations.keys(
            ))  # type: Dict[QuantizerId, List[QuantizerConfig]]
        if self.hw_cfg_type is None:
            for qid in self.qconfig_space_map.keys():
                conf = self.qctrl.all_quantizations[qid].get_quantizer_config()
                conf_list_to_set = []
                for bit in self.model_bitwidth_space:
                    bit_adjusted_conf = deepcopy(conf)
                    bit_adjusted_conf.num_bits = bit
                    conf_list_to_set.append(bit_adjusted_conf)
                self.qconfig_space_map[qid] = conf_list_to_set
        else:
            for qid in self.qconfig_space_map:
                conf_list_to_set = []
                bw_vs_qconfigs_dict = self._hw_precision_constraints.get_bitwidth_vs_qconfigs_dict(
                    qid)
                for bitwidth, qconf_list in bw_vs_qconfigs_dict.items():
                    target_qconf = qconf_list[0]
                    if len(qconf_list) > 1:
                        logger.warning(
                            "Received multiple quantizer configurations {qc_lst} for same bitwidth {bw} "
                            "for quantizer {q} - AutoQ can currently only choose among bitwidths, but not "
                            "within quantizer configuration space with the same bitwidths. Selecting {qc} "
                            "as the target configuration for bitwidth {bw}".
                            format(qc_lst=";".join(
                                [str(qconf) for qconf in qconf_list]),
                                   bw=bitwidth,
                                   q=str(qid),
                                   qc=str(target_qconf)))
                    conf_list_to_set.append(target_qconf)

                self.qconfig_space_map[qid] = conf_list_to_set

        # Quantizer Master Table Creation
        self.groups_of_adjacent_quantizers = self.qctrl._groups_of_adjacent_quantizers
        self.quantizer_table = self._create_quantizer_table()

        # Create master dataframe to keep track of quantizable layers and their attributes
        self.master_df, self.state_list = self._get_state_space(
            self.qctrl, self.qmodel, self.quantizer_table)
        if self.master_df.isnull().values.any():
            raise ValueError("Q.Env Master Dataframe has null value(s)")

        assert len(self.quantizer_table) == len(self.qctrl.all_quantizations), \
            "Number of Quantizer is not tally between quantizer table and quantization controller"

        # MinMaxScaler for State Embedding
        self.state_scaler = MinMaxScaler()
        self.state_scaler.fit(self.master_df[self.state_list])

        # Mapping required for quantizer BW alignment flow
        self.adjq_groupwise_intersecting_bw_space = self._create_map_of_adjq_groupid_to_common_bw_space(
        )
        self.adjq_groupwise_df_lut_keys = self._create_map_of_adjq_groupid_to_df_lut_keys(
        )

        # Model Size Calculation
        self.model_size_calculator = ModelSizeCalculator(
            self.qmodel, self.qconfig_space_map)
        self.orig_model_size = self.model_size_calculator.fp_model_size
        self.min_model_size = self.model_size_calculator.min_model_size
        self.max_model_size = self.model_size_calculator.max_model_size
        self.target_model_size = self.orig_model_size * self.compression_ratio

        if self.target_model_size < self.min_model_size or self.target_model_size > self.max_model_size:
            raise ValueError(
                "Model Size Ratio {} is out of bound ({}, {})".format(
                    self.compression_ratio,
                    self.min_model_size / self.orig_model_size,
                    self.max_model_size / self.orig_model_size))

        # Compression Ratio Calculation (BOP relative to 8-bit)
        self.compression_ratio_calculator = CompressionRatioCalculator(
            self.qmodel.get_flops_per_module(),
            self.qctrl.get_quantizer_setup_for_current_state(), self.qctrl.
            groups_of_adjacent_quantizers.weight_qp_id_per_activation_qp_id)

        # Evaluate and store metric score of pretrained model
        self._evaluate_pretrained_model()
        self.qmodel_init_sd = deepcopy(self.qmodel.state_dict())

        self.reset()

        self._dump_autoq_data = params.dump_init_precision_data
        if self._dump_autoq_data or is_debug():
            dump_dir = params.log_dir
            if dump_dir is None:
                dump_dir = DEBUG_LOG_DIR
            self.dump_dir = Path(dump_dir) / Path("autoq_env_dump")
            self.dump_dir.mkdir(parents=True, exist_ok=True)
            # Serialize Q.Env information. Note that these functions should be at the end of Q.Env Initialization.
            self._dump_master_df()
            self._dump_quantized_graph()
            self._dump_groups_of_adjacent_quantizers()

        # End of QuantizationEnv.__init__()
        # --------------------------------------------------------------------------------------------------------------

    def reset(self):
        self.collected_strategy = []
        self.master_df['action'] = max(self.model_bitwidth_space)
        self.master_df['prev_action'] = 0
        self.master_df['unconstrained_action'] = 0

    def _create_quantizer_table(self) -> pd.DataFrame:
        def get_hook(qid, exec_order_list):
            def register_quantizer_exec_order(module, input_, output, qid,
                                              exec_order_list):
                exec_order_list.append(qid)

            return functools.partial(register_quantizer_exec_order,
                                     qid=qid,
                                     exec_order_list=exec_order_list)

        # Create a mapping of qid to its adjacent quantizer group id
        adjq_gid_map = OrderedDict.fromkeys(
            self.qctrl.all_quantizations.keys())
        for qid in self.qctrl.all_quantizations:
            adjq_gid_map[
                qid] = self.groups_of_adjacent_quantizers.get_group_id_for_quantizer(
                    qid)

        assert len(set(self.qconfig_space_map.keys()) - set(adjq_gid_map.keys())) == 0, \
            "both qconfig_space_map and adjq_gid_map must have exact keys."

        # By design, AutoQ requires quantizers to be visited in execution order.
        # RL assumes that the state satisfies the Markov assumption, i.e. that
        # the future is independent of the past given the current state.
        # Stated differently, the current state should adequately capture the historical dynamics.
        # Given the sequential nature of neural networks, transitioning states in the order
        # in which quantizers are executed is a natural design that conforms to this assumption.
        quantizers_in_exec_order = []
        hooklist = []
        for qid, qmod in self.qctrl.all_quantizations.items():
            hooklist.append(
                qmod.register_forward_hook(
                    get_hook(qid, quantizers_in_exec_order)))
        self.qmodel.do_dummy_forward(force_eval=True)
        for h in hooklist:
            h.remove()

        d = OrderedDict()
        for qid in quantizers_in_exec_order:
            idx_str = str(qid)
            gid = adjq_gid_map[qid]

            d[idx_str] = OrderedDict()
            d[idx_str]['qid'] = str(qid)
            d[idx_str]['gid'] = gid
            d[idx_str]['qconf_space'] = self.qconfig_space_map[qid]
            d[idx_str][
                'qp_id_set'] = self.qctrl.module_id_to_qp_id_translation_dict[
                    qid]
            d[idx_str]['state_scope'] = qid.target_node_name

        # quantizer_table index is QuantizerId in string prepended with its quantize node id in NNCFGraph
        df = pd.DataFrame.from_dict(d, orient='index')
        df['qid_obj'] = df['qid'].apply(
            lambda x: find_qid_by_str(self.qctrl, x))
        df['qmodule'] = df['qid_obj'].apply(
            lambda x: self.qctrl.all_quantizations[x])
        df['is_wt_quantizer'] = df['qid_obj'].apply(
            lambda x: x in self.qctrl.weight_quantizers)
        df['state_module'] = df['state_scope'].apply(
            self.qmodel.get_containing_module)

        return df

    def _get_state_space(
            self, quantization_controller: QuantizationController,
            quantized_model: NNCFNetwork,
            quantizer_table: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
        def annotate_learnable_module_io_shape(model):
            def annotate_io_shape(module, input_, output):
                if hasattr(module, 'weight') or isinstance(
                        module, BaseQuantizer):
                    module.input_shape_ = input_[0].shape
                    module.output_shape_ = output.shape

            hook_list = [
                m.register_forward_hook(annotate_io_shape)
                for n, m in model.named_modules()
            ]
            model.do_dummy_forward(force_eval=True)
            for h in hook_list:
                h.remove()

        annotate_learnable_module_io_shape(quantized_model)

        # State Embedding Extraction
        #---------------------------
        df = quantizer_table
        layer_attr_df = df.apply(self._get_layer_attr, axis=1)
        layer_attr_df['layer_idx'] = np.array(range(len(layer_attr_df)))
        layer_attr_df['weight_quantizer'] = df['is_wt_quantizer'].astype(
            'float')
        state_list = layer_attr_df.columns.to_list()

        # create master dataframe
        master_df = pd.concat([df, layer_attr_df], axis='columns')

        # Annotate a min and a max value in prev_action before minmaxscaler fitting
        master_df.loc[master_df.index[0],
                      'prev_action'] = max(self.model_bitwidth_space)
        master_df.loc[master_df.index[-1],
                      'prev_action'] = min(self.model_bitwidth_space)

        # add GEMM Ops to weight quantizer
        master_df['n_op'] = master_df['state_scope'].map(
            self.qmodel.get_flops_per_module())
        master_df['n_op'] = master_df['n_op'].fillna(0)

        return master_df, state_list

    def _get_layer_attr(self, row: pd.Series) -> pd.Series:
        m = row.state_module
        qid = row.qid_obj
        feature = OrderedDict()

        if isinstance(qid, WeightQuantizerId):
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                feature['conv_dw'] = int(
                    m.weight.shape[1] ==
                    m.groups)  # 1.0 for depthwise, 0.0 for other conv2d
                feature['cin'] = m.weight.shape[1]
                feature['cout'] = m.weight.shape[0]
                feature['stride'] = m.stride[0]
                feature['kernel'] = m.kernel_size[0]
                feature['param'] = np.prod(m.weight.size())
                feature['ifm_size'] = np.prod(m.input_shape_[-2:])  # H*W
                feature['prev_action'] = 0.0  # placeholder

            elif isinstance(m, nn.Linear):
                feature['conv_dw'] = 0.0
                feature['cin'] = m.in_features
                feature['cout'] = m.out_features
                feature['stride'] = 0.0
                feature['kernel'] = 1.0
                feature['param'] = np.prod(m.weight.size())
                feature['ifm_size'] = np.prod(
                    m.input_shape_[-1])  # feature elements
                feature['prev_action'] = 0.0  # placeholder

            else:
                raise NotImplementedError(
                    "State embedding extraction of {}".format(
                        m.__class__.__name__))

        elif isinstance(qid, NonWeightQuantizerId):
            qmod = self.qctrl.all_quantizations[qid]
            input_shape = qmod.input_shape_
            output_shape = qmod.output_shape_
            feature['cin'] = input_shape[1] if len(
                input_shape) == 4 else input_shape[-1]
            feature['cout'] = output_shape[1] if len(
                output_shape) == 4 else output_shape[-1]
            feature['ifm_size'] = np.prod(
                input_shape[-2:]) if len(input_shape) == 4 else input_shape[-1]
            feature['conv_dw'] = 0.0
            feature['stride'] = 0.0
            feature['kernel'] = 0.0
            feature['param'] = 0.0
            feature['prev_action'] = 0.0

            if len(input_shape) != 4 and len(input_shape) != 2:
                raise NotImplementedError(
                    "A design is required to cater this scenario. Pls. report to maintainer"
                )
        else:
            raise ValueError(
                "qid is an instance of unexpected class {}".format(
                    qid.__class__.__name__))

        return pd.Series(feature)

    def _create_map_of_adjq_groupid_to_common_bw_space(self) -> Dict:
        # Extracting common bitwidth space per group of quantizer
        bwassigner_df = deepcopy(self.master_df)
        bwassigner_df['bw_space'] = list(
            map(lambda x: [qc.num_bits for qc in x],
                bwassigner_df.qconf_space.values))

        adjq_groupwise_intersecting_bw_space = {}
        for i, _ in enumerate(self.groups_of_adjacent_quantizers):
            list_of_bw_space = []

            for aq in self.groups_of_adjacent_quantizers[
                    i].activation_quantizers:
                bw_space = bwassigner_df.bw_space[bwassigner_df.qid == str(
                    aq[0])][0]
                list_of_bw_space.append(bw_space)

            for wq in self.groups_of_adjacent_quantizers[i].weight_quantizers:
                bw_space = bwassigner_df.bw_space[bwassigner_df.qid == str(
                    wq[0])][0]
                list_of_bw_space.append(bw_space)

            intersecting_bw_space = set.intersection(
                *map(set, list_of_bw_space))
            adjq_groupwise_intersecting_bw_space[i] = intersecting_bw_space

        return adjq_groupwise_intersecting_bw_space

    def _create_map_of_adjq_groupid_to_df_lut_keys(self) -> Dict:
        adjq_groupwise_df_lut_keys = {}

        for i, _ in enumerate(self.groups_of_adjacent_quantizers):
            group_members = []
            for _, aq in enumerate(self.groups_of_adjacent_quantizers[i].
                                   activation_quantizers):
                group_members.append(
                    self.master_df.index[self.master_df.qid == str(aq[0])][0])
            for _, wq in enumerate(
                    self.groups_of_adjacent_quantizers[i].weight_quantizers):
                group_members.append(
                    self.master_df.index[self.master_df.qid == str(wq[0])][0])
            adjq_groupwise_df_lut_keys[i] = natsorted(group_members)

        return adjq_groupwise_df_lut_keys

    def _evaluate_pretrained_model(self):
        logger.info("[Q.Env] Evaluating Pretrained Model")
        self.qctrl.disable_weight_quantization()
        self.qctrl.disable_activation_quantization()

        with torch.no_grad():
            self.pretrained_score = self.eval_fn(self.qmodel, self.eval_loader)
            logger.info("Pretrained Score: {:.3f}".format(
                self.pretrained_score))

        self.qctrl.enable_weight_quantization()
        self.qctrl.enable_activation_quantization()
        self.qmodel.rebuild_graph()

    def _run_batchnorm_adaptation(self):
        if self._bn_adaptation is None:
            self._bn_adaptation = BatchnormAdaptationAlgorithm(
                **extract_bn_adaptation_init_params(self.qctrl.config,
                                                    "quantization"))
        self._bn_adaptation.run(self.qctrl.model)

    def _run_quantization_pipeline(self, finetune=False) -> float:
        if self.qctrl.config:
            self._run_batchnorm_adaptation()

        if finetune:
            raise NotImplementedError(
                "Post-Quantization fine tuning is not implemented.")
        with torch.no_grad():
            quantized_score = self.eval_fn(self.qmodel, self.eval_loader)
            logger.info(
                "[Q.Env] Quantized Score: {:.3f}".format(quantized_score))
        return quantized_score

    def _get_quantizer_bitwidth(self) -> Dict[BaseQuantizer, int]:
        assert len(set(self.master_df.action.values) - set(self.model_bitwidth_space)) == 0, \
            "there is a bitwidth choice that is not within the model bitwidth space"
        return OrderedDict(zip(self.master_df.qid_obj, self.master_df.action))

    def _constrain_model_size(self,
                              collected_strategy: List,
                              skip=False) -> List:
        def lower_bitwidth(bw: int, qconf_space: List[QuantizerConfig]) -> int:
            bw_space = [qconf.num_bits for qconf in qconf_space]
            assert bw in bw_space
            sorted_bw_space = sorted(bw_space)
            return sorted_bw_space[sorted_bw_space.index(bw) -
                                   1] if sorted_bw_space.index(bw) > 0 else bw

        # This function acts on self.master_df['action']
        self.master_df['action'] = collected_strategy

        if skip is not True:
            self.master_df['unconstrained_action'] = self.master_df['action']

            current_model_size = self.model_size_calculator(
                self._get_quantizer_bitwidth())

            while self.min_model_size < current_model_size and self.target_model_size < current_model_size:
                for _, nodestr in enumerate(
                        reversed(self.master_df.index.tolist())):
                    if self.master_df.loc[nodestr, "is_wt_quantizer"]:
                        bw_choice, qconf_space = self.master_df.loc[
                            nodestr, ['action', 'qconf_space']]
                        new_bw = lower_bitwidth(bw_choice, qconf_space)
                        self.master_df.loc[
                            nodestr,
                            "action"] = new_bw if new_bw != bw_choice else bw_choice

                    current_model_size = self.model_size_calculator(
                        self._get_quantizer_bitwidth())
                    if current_model_size <= self.target_model_size:
                        break
        else:
            logger.info("[Q.Env] Skipping Model Size Constraint")

        return self.master_df['action'].tolist()

    def reward(self, acc: float, model_ratio: float) -> float:
        def order_of_magnitude(number):
            return np.floor(np.log10(abs(number)))

        if self.pretrained_score == 0:
            return acc
        order = order_of_magnitude(self.pretrained_score)
        return acc * (10**(-order))

    def step(self, action: Union[int, float]) -> Tuple:
        currently_processed_qconf_idx = len(self.collected_strategy)

        def is_final_step():
            return len(self.collected_strategy) == len(self.master_df)

        # Ensure action is in the quantizer's bitwidth space
        current_qconf_space = self.master_df.qconf_space[
            currently_processed_qconf_idx]
        current_bw_space = [qconf.num_bits for qconf in current_qconf_space]
        if action not in current_bw_space:
            closest_bw_idx = np.argmin(
                np.abs(action - np.array(current_bw_space)))
            action = current_bw_space[closest_bw_idx]

        self.collected_strategy.append(action)

        if not is_final_step():
            info_set = {}
            reward = 0
            self.set_next_step_prev_action(len(self.collected_strategy),
                                           action)
            obs = self.get_normalized_obs(len(self.collected_strategy))
            done = False
            return obs, reward, done, info_set

        return self.evaluate_strategy(self.collected_strategy,
                                      skip_constraint=self.skip_constraint)

    def select_config_for_actions(
            self, actions) -> Dict[QuantizationPointId, QuantizerConfig]:
        retval = OrderedDict()  # type: Dict[QuantizationPointId, QuantizerConfig]
        for action, qp_id_set, qconf_space in zip(
                actions, self.master_df['qp_id_set'],
                self.master_df['qconf_space']):
            matches = []
            for qconf in qconf_space:
                if qconf.num_bits == action:
                    matches.append(qconf)
            assert len(matches) == 1
            for qp_id in qp_id_set:
                retval[qp_id] = matches[0]
        return retval

    def evaluate_strategy(self,
                          collected_strategy: List,
                          skip_constraint=True) -> Tuple:
        assert len(collected_strategy) == len(self.master_df)
        if not skip_constraint:
            collected_strategy = self._constrain_model_size(collected_strategy)
        self.master_df['action'] = collected_strategy  # this must be set after the size constraint

        if self.performant_bw:
            self._align_bw_action()
            configs_to_set = self.select_config_for_actions(
                self.master_df['action_aligned'])

            if self._dump_autoq_data or is_debug():
                self._dump_adjacent_quantizer_group_alignment()

            self.master_df['action'] = self.master_df['action_aligned']
        else:
            configs_to_set = self.select_config_for_actions(
                self.master_df['action'])

        self._apply_quantizer_configs_to_model(configs_to_set)

        for idx, qid in zip(self.master_df.index, self.master_df['qid']):
            logger.info("[Q.Env] {:50} | {}".format(
                str(self.qctrl.all_quantizations[find_qid_by_str(
                    self.qctrl, qid)]), idx))

        quantized_score = self._run_quantization_pipeline(
            finetune=self.finetune)

        current_model_size = self.model_size_calculator(
            self._get_quantizer_bitwidth())
        current_model_ratio = self.model_size_calculator.get_model_size_ratio(
            self._get_quantizer_bitwidth())

        current_model_bop_ratio = self.compression_ratio_calculator.run_for_quantizer_setup(
            self.qctrl.get_quantizer_setup_for_current_state())

        reward = self.reward(quantized_score, current_model_ratio)

        info_set = {
            'model_ratio': current_model_ratio,
            'accuracy': quantized_score,
            'model_size': current_model_size,
            'bop_ratio': current_model_bop_ratio
        }

        obs = self.get_normalized_obs(len(collected_strategy) - 1)
        done = True
        self._n_eval += 1

        return obs, reward, done, info_set

    def set_next_step_prev_action(self, idx, action):
        self.master_df.loc[self.master_df.index[idx], 'prev_action'] = action

    def get_normalized_obs(self, idx: int) -> pd.Series:
        _df = self.master_df.loc[self.master_df.index, self.state_list]
        _df.loc[_df.index, self.state_list] = self.state_scaler.transform(
            _df[self.state_list])
        return _df.iloc[idx]

    def _apply_quantizer_configs_to_model(
            self, qid_vs_qconfig_map: Dict[QuantizationPointId,
                                           QuantizerConfig]):
        new_quantizer_setup = self.qctrl.get_quantizer_setup_for_current_state()
        for qp_id, qconf in qid_vs_qconfig_map.items():
            new_quantizer_setup.quantization_points[qp_id].qconfig = qconf
        self.qmodel.load_state_dict(self.qmodel_init_sd)
        had_to_regenerate = self.qctrl.is_new_setup_requires_regeneration(
            new_quantizer_setup)
        self.qctrl, self.qmodel = self.qctrl.apply_new_quantizer_setup(
            new_quantizer_setup)
        if had_to_regenerate:
            self.qmodel_init_sd = deepcopy(self.qmodel.state_dict())

        # The QuantizerId's may have changed after the new quantizer setup application, but
        # QuantizationPointId's should not have. Will use this to update the qids in the master table.
        for qid, qp_id_set in self.qctrl.module_id_to_qp_id_translation_dict.items():
            self.master_df.loc[self.master_df.qp_id_set == qp_id_set, 'qid'] = str(qid)

    def _dump_master_df(self):
        self.master_df.drop('state_module', axis=1).to_csv(
            osp.join(self.dump_dir,
                     self.model_name + "_quantizable_state_table.csv"),
            index_label="nodestr")

    def _dump_quantized_graph(self):
        self.qmodel.get_graph().visualize_graph(
            osp.join(self.dump_dir, "qenv_graph.dot"))

    def _dump_groups_of_adjacent_quantizers(self):
        adj_quantizer_groups = []

        for i, _ in enumerate(self.groups_of_adjacent_quantizers):
            group_members = []
            for _, aq in enumerate(self.groups_of_adjacent_quantizers[i].
                                   activation_quantizers):
                group_members.append(
                    self.master_df.index[self.master_df.qid == str(aq[0])][0])
            for _, wq in enumerate(
                    self.groups_of_adjacent_quantizers[i].weight_quantizers):
                group_members.append(
                    self.master_df.index[self.master_df.qid == str(wq[0])][0])
            adj_quantizer_groups.append(natsorted(group_members))

        with safe_open(
                self.dump_dir / "{}_groups_of_adjacent_quantizers.json".format(
                    self.model_name), "w") as DUMP_FH:
            json.dump(natsorted(adj_quantizer_groups), DUMP_FH, indent=4)

    def _align_bw_action(self):
        # align bw action per group of adjacent quantizer
        # this alignment aims to realize GEMM compute in a lower precision

        self.master_df['action_aligned'] = 0

        for i, _ in enumerate(self.groups_of_adjacent_quantizers):
            # Collect all actions of a group
            list_of_action = []

            for _, aq in enumerate(self.groups_of_adjacent_quantizers[i].
                                   activation_quantizers):
                list_of_action.append(
                    self.master_df.action[self.master_df.qid == str(aq[0])].iloc[0])

            for _, wq in enumerate(
                    self.groups_of_adjacent_quantizers[i].weight_quantizers):
                list_of_action.append(
                    self.master_df.action[self.master_df.qid == str(wq[0])].iloc[0])

            # Get the minimum prediction bw of a group
            group_min_predicted_bw = min(list_of_action)

            # Get the intersecting bitwidth space shared by all quantizers in the group
            intersecting_bw_space = self.adjq_groupwise_intersecting_bw_space[i]

            # Determine the lowest realizable hardware precision given current action of the groups
            if group_min_predicted_bw not in intersecting_bw_space:
                group_final_min_bw = min(intersecting_bw_space)
            else:
                group_final_min_bw = group_min_predicted_bw
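            # Illustrative case (hypothetical values): if the group's predicted actions are
            # {8, 4} but the bitwidths realizable by every quantizer in the group are only {8},
            # the predicted minimum of 4 is not realizable and the group is aligned to 8; if
            # the intersecting space were {2, 4, 8}, the group would keep the predicted minimum of 4.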

            # Assignment Routine
            for _, aq in enumerate(self.groups_of_adjacent_quantizers[i].
                                   activation_quantizers):
                self.master_df.loc[self.master_df.qid == str(aq[0]),
                                   "action_aligned"] = group_final_min_bw

            for _, wq in enumerate(
                    self.groups_of_adjacent_quantizers[i].weight_quantizers):
                wq_action = self.master_df.loc[self.master_df.qid == str(wq[0]),
                                               "action"].iloc[0]
                self.master_df.loc[self.master_df.qid == str(wq[0]),
                                   "action_aligned"] = min(wq_action, group_final_min_bw)

    def _dump_adjacent_quantizer_group_alignment(self):
        list_of_dump_dict = []
        for i, _ in enumerate(self.groups_of_adjacent_quantizers):
            list_of_dump_dict.append(
                self.master_df.loc[self.adjq_groupwise_df_lut_keys[i],
                                   ["action", "action_aligned"]].to_dict())

        os.makedirs(self.dump_dir / 'bw_alignment', exist_ok=True)
        with safe_open(
                self.dump_dir /
                'bw_alignment/{0:03d}_bw_alignment.json'.format(self._n_eval),
                "w") as DUMP_FH:
            json.dump(list_of_dump_dict, DUMP_FH, indent=4)
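# A minimal interaction sketch (assumption-laden, not part of the original example): `env` is
# an already constructed instance of the environment above and `agent` is any hypothetical
# policy exposing an act(obs) method; both names are placeholders. Intermediate steps return a
# zero reward, and the final step evaluates the full bitwidth strategy:
#
#     obs = env.get_normalized_obs(0)
#     done = False
#     while not done:
#         action = agent.act(obs)               # propose a bitwidth for the current quantizer
#         obs, reward, done, info = env.step(action)
#     print(info['accuracy'], info['model_ratio'], info['bop_ratio'])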
Example No. 12
class MagnitudeSparsityController(BaseSparsityAlgoController):
    def __init__(self, target_model: NNCFNetwork,
                 sparsified_module_info: List[SparseModuleInfo],
                 config: NNCFConfig):
        super().__init__(target_model, sparsified_module_info)
        self._config = config
        self._algo_config = extract_algo_specific_config(
            self._config, 'magnitude_sparsity')
        params = self._algo_config.get('params', {})

        self._weight_importance_fn = WEIGHT_IMPORTANCE_FUNCTIONS[params.get(
            'weight_importance', 'normed_abs')]
        self._mode = params.get('sparsity_level_setting_mode', 'global')
        self._scheduler = None
        sparsity_init = self._algo_config.get('sparsity_init', 0)

        if self._mode == 'global':
            scheduler_params = deepcopy(params)
            scheduler_params['sparsity_init'] = sparsity_init
            scheduler_cls = SPARSITY_SCHEDULERS.get(
                params.get('schedule', 'polynomial'))
            self._scheduler = scheduler_cls(self, scheduler_params)
        else:
            self._scheduler = StubCompressionScheduler()

        self._bn_adaptation = None

        self.set_sparsity_level(sparsity_init)
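        # Mode summary (as implemented above): in 'global' mode a single model-wide sparsity
        # level is driven by the configured scheduler (polynomial by default), while in any
        # other sparsity_level_setting_mode a stub scheduler is installed and levels can be
        # set per module via set_sparsity_level(..., target_sparsified_module_info=...).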

    def statistics(self,
                   quickly_collected_only: bool = False) -> NNCFStatistics:
        collector = PTSparseModelStatisticsCollector(
            self.model, self.sparsified_module_info)
        model_statistics = collector.collect()

        threshold_statistics = []
        if self._mode == 'global':
            global_threshold = self._select_threshold(
                model_statistics.sparsity_level_for_layers,
                self.sparsified_module_info)

        module_name_to_sparsity_level_map = {
            s.name: s.sparsity_level
            for s in model_statistics.sparsified_layers_summary
        }
        for minfo in self.sparsified_module_info:
            if self._mode == 'global':
                threshold = global_threshold
            else:
                sparsity_level_for_sparse_module = module_name_to_sparsity_level_map[
                    minfo.module_node_name]
                threshold = self._select_threshold(
                    sparsity_level_for_sparse_module, [minfo])

            threshold_statistics.append(
                LayerThreshold(minfo.module_node_name, threshold))

        target_sparsity_level = self.scheduler.current_sparsity_level if self._mode == 'global' else None

        stats = MagnitudeSparsityStatistics(model_statistics,
                                            threshold_statistics,
                                            target_sparsity_level)

        nncf_stats = NNCFStatistics()
        nncf_stats.register('magnitude_sparsity', stats)
        return nncf_stats

    def freeze(self, freeze: bool = True):
        for layer in self.sparsified_module_info:
            layer.operand.frozen = freeze

    @property
    def compression_rate(self):
        return self.statistics().magnitude_sparsity.model_statistics.sparsity_level

    @compression_rate.setter
    def compression_rate(self, sparsity_level: float):
        self.freeze(False)
        self.set_sparsity_level(sparsity_level)
        self.freeze(True)

    def set_sparsity_level(
            self,
            sparsity_level,
            target_sparsified_module_info: SparseModuleInfo = None,
            run_batchnorm_adaptation: bool = False):
        if sparsity_level >= 1 or sparsity_level < 0:
            raise AttributeError(
                'Sparsity level should be within interval [0,1), actual value to set is: {}'
                .format(sparsity_level))
        if target_sparsified_module_info is None:
            target_sparsified_module_info_list = self.sparsified_module_info  # List[SparseModuleInfo]
        else:
            target_sparsified_module_info_list = [
                target_sparsified_module_info
            ]
        threshold = self._select_threshold(sparsity_level,
                                           target_sparsified_module_info_list)
        self._set_masks_for_threshold(threshold,
                                      target_sparsified_module_info_list)

        if run_batchnorm_adaptation:
            self._run_batchnorm_adaptation()

    def _select_threshold(self, sparsity_level,
                          target_sparsified_module_info_list):
        all_weights = self._collect_all_weights(
            target_sparsified_module_info_list)
        if not all_weights:
            return 0.0
        all_weights_tensor, _ = torch.cat(all_weights).sort()
        threshold = all_weights_tensor[int(
            (all_weights_tensor.size(0) - 1) * sparsity_level)].item()
        return threshold
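    # Worked example (illustrative numbers): with 1000 collected weight-importance values and
    # a requested sparsity_level of 0.3, the values are sorted ascending and the entry at
    # index int(999 * 0.3) == 299 becomes the threshold, so roughly the 30% least important
    # weights fall at or below it and are zeroed by the masks applied below.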

    def _set_masks_for_threshold(self, threshold_val,
                                 target_sparsified_module_info_list):
        for layer in target_sparsified_module_info_list:
            if not layer.operand.frozen:
                layer.operand.binary_mask = calc_magnitude_binary_mask(
                    layer.module.weight, self._weight_importance_fn,
                    threshold_val)

    def _collect_all_weights(
            self, target_sparsified_module_info_list: List[SparseModuleInfo]):
        all_weights = []
        for minfo in target_sparsified_module_info_list:
            all_weights.append(
                self._weight_importance_fn(minfo.module.weight).view(-1))
        return all_weights

    def compression_stage(self) -> CompressionStage:
        if self._mode == 'local':
            return CompressionStage.FULLY_COMPRESSED

        if self.scheduler.current_sparsity_level == 0:
            return CompressionStage.UNCOMPRESSED
        if self.scheduler.current_sparsity_level >= self.scheduler.target_level:
            return CompressionStage.FULLY_COMPRESSED
        return CompressionStage.PARTIALLY_COMPRESSED

    def _run_batchnorm_adaptation(self):
        if self._bn_adaptation is None:
            self._bn_adaptation = BatchnormAdaptationAlgorithm(
                **extract_bn_adaptation_init_params(self._config,
                                                    'magnitude_sparsity'))
        self._bn_adaptation.run(self.model)
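# A minimal usage sketch (hedged): this controller is normally obtained from NNCF's
# create_compressed_model() entry point rather than constructed directly; exact import paths
# depend on the NNCF version, and the names `model`, `num_epochs`, `train_one_epoch`,
# `train_loader` and the JSON file name are placeholders for user code, not part of this example.
#
#     from nncf import NNCFConfig
#     from nncf.torch import create_compressed_model
#
#     nncf_config = NNCFConfig.from_json("magnitude_sparsity.json")
#     compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)
#     for epoch in range(num_epochs):
#         compression_ctrl.scheduler.epoch_step()  # advances the polynomial sparsity schedule
#         train_one_epoch(compressed_model, train_loader)
#     stats = compression_ctrl.statistics()        # NNCFStatistics with a 'magnitude_sparsity' entry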