Example #1
    def _sparsify_weights(
            self, target_model: NNCFNetwork) -> List[PTInsertionCommand]:
        """Build insertion commands that attach a weight-sparsifying
        operation to every supported weighted module of the model."""
        device = next(target_model.parameters()).device
        sparsified_module_nodes = target_model.get_weighted_original_graph_nodes(
            nncf_module_names=self.compressed_nncf_module_names)
        insertion_commands = []
        for module_node in sparsified_module_nodes:
            node_name = module_node.node_name

            if not self._should_consider_scope(node_name):
                nncf_logger.info(
                    f"Ignored adding Weight Sparsifier in scope: {node_name}")
                continue

            nncf_logger.info(f"Adding Weight Sparsifier in scope: {node_name}")
            compression_lr_multiplier = \
                self.config.get_redefinable_global_param_value_for_algo(
                    'compression_lr_multiplier', self.name)
            operation = self.create_weight_sparsifying_operation(
                module_node, compression_lr_multiplier)
            # Move the sparsifier to the model's device before insertion.
            hook = operation.to(device)
            insertion_commands.append(
                PTInsertionCommand(
                    PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS,
                                  target_node_name=node_name), hook,
                    TransformationPriority.SPARSIFICATION_PRIORITY))
            # Record the sparsified module for later statistics and export.
            sparsified_module = target_model.get_containing_module(node_name)
            self._sparsified_module_info.append(
                SparseModuleInfo(node_name, sparsified_module, hook))

        return insertion_commands
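
The commands returned here are applied to the model later by NNCF's transformation machinery; the hook itself is just a module invoked on the weights at forward time. As a rough, framework-free illustration of the same idea, the sketch below (plain PyTorch, function name invented for this example) masks the smallest-magnitude weights with a forward pre-hook. Note that NNCF's real operation applies its mask non-destructively, whereas this sketch zeroes the weights in place.

import torch
import torch.nn as nn

def attach_magnitude_sparsifier(module: nn.Module, sparsity_level: float):
    """Zero out the smallest-magnitude weights before every forward pass.
    Illustrative only: unlike NNCF's operation, this mutates the weights."""
    def hook(mod, inputs):
        with torch.no_grad():
            w = mod.weight
            k = int(w.numel() * sparsity_level)
            if k > 0:
                threshold = w.abs().flatten().kthvalue(k).values
                mod.weight.mul_((w.abs() > threshold).to(w.dtype))
    return module.register_forward_pre_hook(hook)

conv = nn.Conv2d(3, 8, kernel_size=3)
handle = attach_magnitude_sparsifier(conv, sparsity_level=0.5)
_ = conv(torch.randn(1, 3, 16, 16))  # roughly half the weights are now zero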
Example #2
    def __init__(self, target_model: NNCFNetwork,
                 sparsified_module_info: List[SparseModuleInfo]):
        super().__init__(target_model)
        # This controller adds no loss term and uses the default scheduler.
        self._loss = ZeroCompressionLoss(
            next(target_model.parameters()).device)
        self._scheduler = BaseCompressionScheduler()
        self.sparsified_module_info = sparsified_module_info
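
ZeroCompressionLoss lets the training loop treat every controller uniformly: a compression term is always present in the total loss, it just evaluates to zero here. A minimal stand-in (hypothetical class name, not NNCF's implementation) looks like this:

import torch

class ZeroLossStub:
    """Hypothetical stand-in for ZeroCompressionLoss: a compression loss
    that always evaluates to zero on the model's device."""
    def __init__(self, device: torch.device):
        self._device = device

    def __call__(self) -> torch.Tensor:
        return torch.zeros([], device=self._device)

model = torch.nn.Linear(4, 2)
compression_loss = ZeroLossStub(next(model.parameters()).device)
task_loss = model(torch.randn(8, 4)).pow(2).mean()
total_loss = task_loss + compression_loss()  # identical to task_loss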
Example #3
    def __init__(self, target_model: NNCFNetwork, config: NNCFConfig):
        super().__init__(target_model)

        # Binarization introduces no extra loss term of its own.
        self._loss = ZeroCompressionLoss(
            next(target_model.parameters()).device)
        scheduler_cls = QUANTIZATION_SCHEDULERS.get("staged")
        algo_config = extract_algo_specific_config(config, "binarization")
        self._scheduler = scheduler_cls(self, algo_config.get("params", {}))
        from nncf.torch.utils import is_main_process
        if is_main_process():
            # Report the achieved binarization rate once, from the main
            # process only (matters for distributed training).
            self._compute_and_display_flops_binarization_rate()
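
The main-process guard matters under multi-GPU training: only one rank should print the binarization-rate report. NNCF's is_main_process comes from nncf.torch.utils; a common implementation of such a check (shown here as an assumption, not NNCF's actual code) is:

import torch.distributed as dist

def is_main_process_sketch() -> bool:
    # Rank 0 counts as the main process; if torch.distributed is not
    # initialized, there is a single process, so it is the main one.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True

if is_main_process_sketch():
    print("only rank 0 reports the binarization rate")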
Example #4
    def _binarize_weights_and_module_inputs(
            self, target_model: NNCFNetwork) -> List[PTInsertionCommand]:
        """Build insertion commands that binarize both the weights and the
        inputs of every supported weighted module of the model."""
        device = next(target_model.parameters()).device

        module_nodes = target_model.get_weighted_original_graph_nodes(
            nncf_module_names=self.compressed_nncf_module_names)

        insertion_commands = []
        for module_node in module_nodes:
            if not self._should_consider_scope(module_node.node_name):
                nncf_logger.info(
                    f"Ignored adding binarizers in scope: {module_node.node_name}")
                continue

            nncf_logger.info(
                f"Adding Weight binarizer in scope: {module_node.node_name}")
            op_weights = self.__create_binarize_module().to(device)

            nncf_logger.info(
                f"Adding Activation binarizer in scope: {module_node.node_name}")

            compression_lr_multiplier = self.config.get_redefinable_global_param_value_for_algo(
                'compression_lr_multiplier', self.name)

            # Binarize the activations entering the module, with scale and
            # threshold parameters shaped after the module's weights.
            op_inputs = UpdateInputs(
                ActivationBinarizationScaleThreshold(
                    module_node.layer_attributes.get_weight_shape(),
                    compression_lr_multiplier=compression_lr_multiplier)).to(device)

            # Insert the weight binarizer at the weighted operation itself...
            ip_w = PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS,
                                 target_node_name=module_node.node_name)
            insertion_commands.append(
                PTInsertionCommand(
                    ip_w, op_weights,
                    TransformationPriority.QUANTIZATION_PRIORITY))

            # ...and the activation binarizer on the module's first input port.
            ip_i = PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
                                 target_node_name=module_node.node_name,
                                 input_port_id=0)
            insertion_commands.append(
                PTInsertionCommand(
                    ip_i, op_inputs,
                    TransformationPriority.QUANTIZATION_PRIORITY))
        return insertion_commands
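
Each module thus receives two hooks: a weight binarizer at the operation itself and an activation binarizer on input port 0. The core trick behind trainable binarization is the straight-through estimator: the forward pass uses the binarized value while the backward pass treats binarization as identity, so the latent full-precision weights keep learning. A self-contained sketch of that estimator (not NNCF's actual XNOR/DoReFa operations):

import torch
import torch.nn as nn

class SignBinarizer(nn.Module):
    """Illustrative weight binarizer: w -> sign(w) * mean(|w|), with a
    straight-through gradient so the latent weights stay trainable."""
    def forward(self, w: torch.Tensor) -> torch.Tensor:
        binary = torch.sign(w) * w.abs().mean()
        # Forward returns the binary value; backward passes gradients
        # through to the full-precision weights unchanged.
        return w + (binary - w).detach()

w = torch.randn(8, 3, 3, 3, requires_grad=True)
wb = SignBinarizer()(w)
wb.sum().backward()  # w.grad is all ones: gradients reach the latent weights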
Example #5
    def __init__(self, target_model: NNCFNetwork, prunable_types: List[str],
                 pruned_module_groups_info: Clusterization[PrunedModuleInfo],
                 config: NNCFConfig):
        super().__init__(target_model)
        # Filter pruning contributes no extra loss term.
        self._loss = ZeroCompressionLoss(
            next(target_model.parameters()).device)
        self._prunable_types = prunable_types
        self.config = config
        self.pruning_config = extract_algo_specific_config(
            config, 'filter_pruning')
        params = self.pruning_config.get('params', {})
        self.pruned_module_groups_info = pruned_module_groups_info
        # Behavior flags, each falling back to a default when absent.
        self.prune_batch_norms = params.get('prune_batch_norms', True)
        self.prune_first = params.get('prune_first_conv', False)
        self.prune_downsample_convs = params.get('prune_downsample_convs',
                                                 False)
        self.prune_flops = False
        self.check_pruning_level(params)
        self._hooks = []
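
The fields read from params map directly onto the algorithm section of the NNCF config. A fragment that exercises the defaults shown in this constructor might look like the following (hypothetical values, written here as a plain dict):

config_fragment = {
    "algorithm": "filter_pruning",
    "params": {
        "prune_batch_norms": True,        # also mask Batch/Group Norm layers
        "prune_first_conv": False,        # keep the stem convolution intact
        "prune_downsample_convs": False,  # keep downsample convolutions intact
    },
}

params = config_fragment.get("params", {})
assert params.get("prune_first_conv", False) is False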
Example #6
    def _prune_weights(self, target_model: NNCFNetwork):
        """Build insertion commands that attach pruning operations to every
        group of jointly pruned modules, then propagate the resulting masks
        to the dependent normalization layers."""
        target_model_graph = target_model.get_original_graph()
        groups_of_nodes_to_prune = self.pruning_node_selector.create_pruning_groups(
            target_model_graph)

        device = next(target_model.parameters()).device
        insertion_commands = []
        self.pruned_module_groups_info = Clusterization[PrunedModuleInfo](
            lambda x: x.node_name)

        for i, group in enumerate(groups_of_nodes_to_prune.get_all_clusters()):
            group_minfos = []
            for node in group.elements:
                node_name = node.node_name
                module = target_model.get_containing_module(node_name)
                module_scope = target_model_graph.get_scope_by_node_name(
                    node_name)
                # Sanity check: the selector must only pick prunable modules.
                assert self._is_pruned_module(module)

                nncf_logger.info(
                    f"Adding Weight Pruner in scope: {node_name}")
                pruning_block = self.create_weight_pruning_operation(
                    module, node_name)
                # Hook for weights and bias
                hook = UpdateWeightAndBias(pruning_block).to(device)
                insertion_commands.append(
                    PTInsertionCommand(
                        PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
                                      target_node_name=node_name), hook,
                        TransformationPriority.PRUNING_PRIORITY))
                group_minfos.append(
                    PrunedModuleInfo(
                        node_name=node_name,
                        module_scope=module_scope,
                        module=module,
                        operand=pruning_block,
                        node_id=node.node_id,
                        is_depthwise=is_prunable_depthwise_conv(node)))

            cluster = Cluster[PrunedModuleInfo](
                i, group_minfos, [n.node_id for n in group.elements])
            self.pruned_module_groups_info.add_cluster(cluster)

        # Propagate masks to find norm layers to prune
        init_output_masks_in_graph(
            target_model_graph, self.pruned_module_groups_info.get_all_nodes())
        MaskPropagationAlgorithm(
            target_model_graph, PT_PRUNING_OPERATOR_METATYPES,
            PTNNCFPruningTensorProcessor).mask_propagation()

        # Also add binary masks to Batch/Group Norm layers so that the
        # propagated masks can be applied to them as well
        types_to_apply_mask = ['group_norm']
        if self.prune_batch_norms:
            types_to_apply_mask.append('batch_norm')

        all_norm_layers = target_model_graph.get_nodes_by_types(
            types_to_apply_mask)
        for node in all_norm_layers:
            if node.data['output_mask'] is None:
                # Skip elements that will not be pruned
                continue

            node_name = node.node_name
            module = target_model.get_containing_module(node_name)

            pruning_block = self.create_weight_pruning_operation(
                module, node_name)
            # Hook for weights and bias
            hook = UpdateWeightAndBias(pruning_block).to(device)
            insertion_commands.append(
                PTInsertionCommand(
                    PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
                                  target_node_name=node_name), hook,
                    TransformationPriority.PRUNING_PRIORITY))
            self._pruned_norms_operators.append((node, pruning_block, module))
        return insertion_commands
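
After mask propagation, each pruned group and its dependent norm layers carry a binary mask over the output channels, and UpdateWeightAndBias applies that mask to both the weight and the bias. The effect can be reproduced in isolation (helper name invented for this sketch):

import torch
import torch.nn as nn

def apply_filter_mask(weight: torch.Tensor, bias: torch.Tensor,
                      mask: torch.Tensor):
    """Zero whole output filters (and the matching bias entries) with a
    binary mask over the output-channel dimension."""
    shaped = mask.view(-1, *([1] * (weight.dim() - 1)))
    return weight * shaped, bias * mask

conv = nn.Conv2d(3, 4, kernel_size=3)
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])  # prune filters 1 and 3
w, b = apply_filter_mask(conv.weight, conv.bias, mask)
assert torch.all(w[1] == 0) and b[1] == 0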