示例#1
0
def test_get_first_pruned_layers(model, ref_first_module_names):
    """Check that filter pruning reports exactly the expected first prunable modules.

    :param model: factory, called with no arguments, that builds the model under test
    :param ref_first_module_names: attribute names (looked up on the compressed model) of the
        modules expected to be returned by ``get_first_pruned_modules``
    """
    pruning_config = get_basic_pruning_config(input_sample_size=(1, 1, 8, 8))
    pruning_config['compression']['algorithm'] = 'filter_pruning'
    compressed_model, _ = create_compressed_model_and_algo_for_test(model(), pruning_config)

    pruned_op_types = FilterPruningBuilder(pruning_config).get_types_of_pruned_modules()
    actual_first = get_first_pruned_modules(compressed_model, pruned_op_types)
    expected_first = [getattr(compressed_model, attr_name) for attr_name in ref_first_module_names]

    # Order is irrelevant; only the membership of the two collections is compared.
    assert set(actual_first) == set(expected_first)
示例#2
0
    def _is_module_prunable(self, target_model: NNCFNetwork, module, module_scope: Scope):
        """
        Decide whether a weight pruner should be added to a module, given the algorithm settings.

        The checks run in a fixed order, and the first one that fires decides the result:
        frozen weights (when ``ignore_frozen_layers`` is set), scope filtering, first
        convolutions (unless ``prune_first``), last convolutions (unless ``prune_last``),
        grouped convolutions (depthwise ones are accepted), and convolutions with downsampling
        (unless ``prune_downsample_convs``).

        :param target_model: model to work with
        :param module: module to check
        :param module_scope: scope of the module
        :return: tuple ``(prune, msg)``. ``prune`` is True when the module should be pruned.
            ``msg`` explains why the module is skipped and is None when ``prune`` is True.
        """
        op_types = self.get_op_types_of_pruned_modules()
        # 'linear' is added so that linear layers also count as graph boundaries
        # when searching for the first/last prunable modules.
        first_modules = get_first_pruned_modules(target_model, op_types + ['linear'])
        last_modules = get_last_pruned_modules(target_model, op_types + ['linear'])
        scope_str = str(module_scope)

        if self.ignore_frozen_layers and not module.weight.requires_grad:
            return False, ("Ignored adding Weight Pruner in scope: {} because"
                           " the layer appears to be frozen (requires_grad=False)").format(scope_str)
        if not self._should_consider_scope(scope_str):
            return False, "Ignored adding Weight Pruner in scope: {}".format(scope_str)
        if not self.prune_first and module in first_modules:
            return False, ("Ignored adding Weight Pruner in scope: {} because"
                           " this scope is one of the first convolutions").format(scope_str)
        if not self.prune_last and module in last_modules:
            return False, ("Ignored adding Weight Pruner in scope: {} because"
                           " this scope is one of the last convolutions").format(scope_str)
        if is_grouped_conv(module):
            # Depthwise convolutions are the one kind of grouped convolution that stays
            # prunable; note that they return here and skip the downsampling check below.
            if is_depthwise_conv(module):
                return True, None
            return False, ("Ignored adding Weight Pruner in scope: {} because"
                           " this scope is grouped convolution").format(scope_str)
        if not self.prune_downsample_convs and is_conv_with_downsampling(module):
            return False, ("Ignored adding Weight Pruner in scope: {} because"
                           " this scope is convolution with downsample").format(scope_str)
        return True, None
示例#3
0
    def _prune_weights(self, target_model: NNCFNetwork):
        """
        Create insertion commands that attach weight pruners to the prunable NNCF modules.

        A module of a pruned type is skipped (with an info log) when the first of these checks
        fires:
        - its weight has ``requires_grad`` False while ``ignore_frozen_layers`` is set;
        - it is filtered out by scope;
        - it is one of the first convolutions (unless ``prune_first``);
        - it is one of the last convolutions (unless ``prune_last``);
        - it is a grouped convolution;
        - it is a convolution with downsampling (unless ``prune_downsample_convs``).

        Every other module gets an ``UpdateWeight`` pre-op hook and a ``PrunedModuleInfo``
        entry in ``self._pruned_module_info``.

        Skipped grouped convolutions that are also depthwise still contribute bookkeeping.
        If a preceding convolution is found, the batch norm found for the depthwise scope is
        recorded under that convolution's scope. When ``prune_batch_norms`` is set, this
        mapping is passed to ``self.update_minfo_with_depthwise_bn``.

        :param target_model: model whose NNCF modules are inspected
        :return: list of InsertionCommand objects, one per pruned module
        """
        device = next(target_model.parameters()).device
        candidates = target_model.get_nncf_modules()
        commands = []
        # Maps the scope string of the conv preceding a depthwise conv to the BN found for it.
        depthwise_bns = {}

        first_modules = get_first_pruned_modules(
            target_model, self.get_types_of_pruned_modules() + ['linear'])
        last_modules = get_last_pruned_modules(
            target_model, self.get_types_of_pruned_modules() + ['linear'])

        for scope, nncf_module in candidates.items():
            # Only op types handled by this algorithm are considered at all.
            if not self._is_pruned_module(nncf_module):
                continue

            scope_str = str(scope)
            skip_template = None
            if self.ignore_frozen_layers and not nncf_module.weight.requires_grad:
                skip_template = ("Ignored adding Weight Pruner in scope: {} because"
                                 " the layer appears to be frozen (requires_grad=False)")
            elif not self._should_consider_scope(scope_str):
                skip_template = "Ignored adding Weight Pruner in scope: {}"
            elif not self.prune_first and nncf_module in first_modules:
                skip_template = ("Ignored adding Weight Pruner in scope: {} because"
                                 " this scope is one of the first convolutions")
            elif not self.prune_last and nncf_module in last_modules:
                skip_template = ("Ignored adding Weight Pruner in scope: {} because"
                                 " this scope is one of the last convolutions")
            elif is_grouped_conv(nncf_module):
                if is_depthwise_conv(nncf_module):
                    prev_conv = get_previous_conv(target_model, nncf_module, scope)
                    if prev_conv:
                        bn = get_bn_for_module_scope(target_model, scope)
                        depthwise_bns[str(prev_conv.op_exec_context.scope_in_model)] = bn
                skip_template = ("Ignored adding Weight Pruner in scope: {} because"
                                 " this scope is grouped convolution")
            elif not self.prune_downsample_convs and is_conv_with_downsampling(nncf_module):
                skip_template = ("Ignored adding Weight Pruner in scope: {} because"
                                 " this scope is convolution with downsample")

            if skip_template is not None:
                nncf_logger.info(skip_template.format(scope_str))
                continue

            nncf_logger.info("Adding Weight Pruner in scope: {}".format(scope_str))
            hook = UpdateWeight(self.create_weight_pruning_operation(nncf_module)).to(device)
            pre_op_point = InsertionPoint(
                InputAgnosticOperationExecutionContext("", scope, 0),
                InsertionType.NNCF_MODULE_PRE_OP)
            commands.append(
                InsertionCommand(pre_op_point, hook, OperationPriority.PRUNING_PRIORITY))

            related_modules = {}
            if self.prune_batch_norms:
                related_modules[PrunedModuleInfo.BN_MODULE_NAME] = get_bn_for_module_scope(
                    target_model, scope)
            self._pruned_module_info.append(
                PrunedModuleInfo(scope_str, nncf_module, hook.operand, related_modules))

        if self.prune_batch_norms:
            self.update_minfo_with_depthwise_bn(depthwise_bns)

        return commands
    def _prune_weights(self, target_model: NNCFNetwork):
        """
        Create insertion commands that attach weight pruners to the prunable NNCF modules.

        A module of a pruned type is skipped (with an info log) when any of these holds:
        - it is filtered out by scope;
        - it is one of the first convolutions (unless ``prune_first``);
        - it is one of the last convolutions (unless ``prune_last``);
        - it is a convolution with downsampling (unless ``prune_downsample_convs``).

        Every other module gets an ``UpdateWeight`` pre-op hook and a ``PrunedModuleInfo``
        entry in ``self._pruned_module_info``. When ``prune_batch_norms`` is set, the batch
        norm found for the module's scope is stored in that entry's related modules.

        NOTE(review): this variant neither consults ``ignore_frozen_layers`` nor
        special-cases grouped convolutions. Confirm that this is intended.

        :param target_model: model whose NNCF modules are inspected
        :return: list of InsertionCommand objects, one per pruned module
        """
        device = next(target_model.parameters()).device
        modules_to_prune = target_model.get_nncf_modules()
        insertion_commands = []

        # 'linear' is included so that linear layers also count as graph boundaries
        # when searching for the first/last prunable modules.
        input_non_pruned_modules = get_first_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])
        output_non_pruned_modules = get_last_pruned_modules(
            target_model,
            self.get_types_of_pruned_modules() + ['linear'])

        for module_scope, module in modules_to_prune.items():
            # Check that we need to prune weights in this op
            if not self._is_pruned_module(module):
                continue

            module_scope_str = str(module_scope)
            if not self._should_consider_scope(module_scope_str):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {}".format(
                        module_scope_str))
                continue

            if not self.prune_first and module in input_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the first convolutions".format(
                        module_scope_str))
                continue
            if not self.prune_last and module in output_non_pruned_modules:
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is one of the last convolutions".format(
                        module_scope_str))
                continue

            if not self.prune_downsample_convs and is_conv_with_downsampling(
                    module):
                nncf_logger.info(
                    "Ignored adding Weight Pruner in scope: {} because"
                    " this scope is convolution with downsample".format(
                        module_scope_str))
                continue

            nncf_logger.info(
                "Adding Weight Pruner in scope: {}".format(module_scope_str))
            operation = self.create_weight_pruning_operation(module)
            hook = UpdateWeight(operation).to(device)
            insertion_commands.append(
                InsertionCommand(
                    InsertionPoint(
                        InputAgnosticOperationExecutionContext(
                            "", module_scope, 0),
                        InsertionType.NNCF_MODULE_PRE_OP), hook,
                    OperationPriority.PRUNING_PRIORITY))

            related_modules = {}
            if self.prune_batch_norms:
                # Use the shared key constant instead of a hard-coded string so that every
                # consumer of PrunedModuleInfo looks the BN up under the same key.
                related_modules[PrunedModuleInfo.BN_MODULE_NAME] = get_bn_for_module_scope(
                    target_model, module_scope)

            self._pruned_module_info.append(
                PrunedModuleInfo(module_scope_str, module, hook.operand,
                                 related_modules))

        return insertion_commands