def load_state(model: torch.nn.Module, state_dict_to_load: dict, is_resume: bool = False,
               keys_to_ignore: List[str] = None) -> int:
    """
    Used to load a checkpoint containing a compressed model into an NNCFNetwork object, but can
    be used for any PyTorch module as well. Will do matching of state_dict_to_load parameters to
    the model's state_dict parameters while discarding irrelevant prefixes added during wrapping
    in NNCFNetwork or DataParallel/DistributedDataParallel objects, and load the matched parameters
    from the state_dict_to_load into the model's state dict.
    :param model: The target module for the state_dict_to_load to be loaded to.
    :param state_dict_to_load: A state dict containing the parameters to be loaded into the model.
    :param is_resume: Determines the behavior when the function cannot do a successful parameter match
    when loading. If True, the function will raise an exception if it cannot match the state_dict_to_load
    parameters to the model's parameters (i.e. if some parameters required by model are missing in
    state_dict_to_load, or if state_dict_to_load has parameters that could not be matched to model parameters,
    or if the shape of parameters is not matching). If False, the exception won't be raised.
    Usually is_resume is specified as False when loading uncompressed model's weights into the model with
    compression algorithms already applied, and as True when loading a compressed model's weights into the model
    with compression algorithms applied to evaluate the model.
    :param keys_to_ignore: A list of parameter names that should be skipped from matching process.
    :return: The number of state_dict_to_load entries successfully matched and loaded into model.
    """

    model_state_dict = model.state_dict()

    maybe_convert_legacy_names_in_model_state(state_dict_to_load)
    key_matcher = KeyMatcher(is_resume, state_dict_to_load, model_state_dict, keys_to_ignore)
    new_dict = key_matcher.run()
    num_loaded_params = len(new_dict)
    key_matcher.handle_problematic_keys()
    nncf_logger.info("Loaded {}/{} parameters".format(num_loaded_params, len(model_state_dict.items())))

    model.load_state_dict(new_dict, strict=False)
    return num_loaded_params
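
A minimal usage sketch, assuming a checkpoint saved with an optional 'state_dict' key (the same pattern as the checkpoint-loading examples below); the path is hypothetical:

    import torch

    checkpoint = torch.load('checkpoints/best.pth', map_location='cpu')
    state_dict = checkpoint.get('state_dict', checkpoint)
    # is_resume=True enforces an exact parameter match, as when resuming
    # or evaluating a compressed model.
    num_loaded = load_state(model, state_dict, is_resume=True)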
Example #2
    def _filter_groups(self, pruned_nodes_clusterization: Clusterization,
                       can_prune: Dict[int, PruningAnalysisDecision]) -> None:
        """
        Check whether all nodes in group can be pruned based on user-defined constraints and
        connections inside the network. Otherwise the whole group cannot be pruned and will be
        removed the clusterization.

        :param pruned_nodes_clusterization: Clusterization with pruning nodes groups.
        :param can_prune: Can this node be pruned or not.
        """
        for cluster in pruned_nodes_clusterization.get_all_clusters():
            nodes_decisions = [
                can_prune[node.node_id] for node in cluster.elements
            ]
            nodes_names = [node.node_name for node in cluster.elements]
            if not all(nodes_decisions):
                cannot_prune_messages = []
                for name, decision in zip(nodes_names, nodes_decisions):
                    if not decision:
                        message = PruningAnalysisReason.message(name, decision)
                        if message:
                            cannot_prune_messages.append(message)

                nncf_logger.info(
                    'Group of nodes [{}] can\'t be pruned, because some of its nodes shouldn\'t be pruned; '
                    'error messages for these nodes: {}.'.format(
                        ', '.join(nodes_names),
                        ', '.join(cannot_prune_messages)))
                pruned_nodes_clusterization.delete_cluster(cluster.id)
            else:
                nncf_logger.info(
                    'Group of nodes [{}] will be pruned together.'.format(
                        ", ".join(nodes_names)))
Example #3
 def dump_metric_MB(self, metric_per_qconfig_sequence: List[Tensor]):
     import matplotlib.pyplot as plt
     list_to_plot = [cm.item() for cm in metric_per_qconfig_sequence]
     fig = plt.figure()
     fig.suptitle('Pareto Frontier')
     ax = fig.add_subplot(2, 1, 1)
     ax.set_yscale('log')
     ax.set_xlabel('Model Size (MB)')
     ax.set_ylabel('Metric value (total perturbation)')
     ax.scatter(self._model_sizes,
                list_to_plot,
                s=20,
                facecolors='none',
                edgecolors='r')
     cm = torch.Tensor(metric_per_qconfig_sequence)
     # torch.median returns an element of the sequence, so .index() can locate it
     cm_m = cm.median().item()
     qconfig_index = metric_per_qconfig_sequence.index(cm_m)
     ms_m = self._model_sizes[qconfig_index]
     ax.scatter(ms_m,
                cm_m,
                s=30,
                facecolors='none',
                edgecolors='b',
                label='median from all metrics')
     ax.legend()
     plt.savefig(os.path.join(self._dump_dir, 'Pareto_Frontier'))
     nncf_logger.info(
         'Distribution of HAWQ metrics: min_value={:.3f}, max_value={:.3f}, median_value={:.3f}, '
         'median_index={}, total_number={}'.format(
             cm.min().item(),
             cm.max().item(), cm_m, qconfig_index,
             len(metric_per_qconfig_sequence)))
Example #4
    def set_pruning_level(self,
                          pruning_level: float,
                          run_batchnorm_adaptation: bool = False):
        """
        Setup pruning masks in accordance to provided pruning rate
        :param pruning_level: pruning ration
        :return:
        """
        # Pruning rate from scheduler can be percentage of params that should be pruned
        self.pruning_rate = pruning_level
        if not self.frozen:
            nncf_logger.info(
                'Computing filter importance scores and binary masks...')
            if self.all_weights:
                if self.prune_flops:
                    self._set_binary_masks_for_pruned_modules_globally_by_flops_target(
                        pruning_level)
                else:
                    self._set_binary_masks_for_pruned_layers_globally(
                        pruning_level)
            else:
                if self.prune_flops:
                    # Looking for a layerwise pruning rate needed for the required flops pruning rate
                    pruning_level = self._find_uniform_pruning_level_for_target_flops(
                        pruning_level)
                self._set_binary_masks_for_pruned_layers_groupwise(
                    pruning_level)

        if run_batchnorm_adaptation:
            self._run_batchnorm_adaptation()
Example #5
    def output_prune(cls, model: NNCFNetwork, node: NNCFNode,
                     graph: NNCFGraph) -> None:
        output_mask = node.data['output_mask']
        if output_mask is None:
            return

        bool_mask = torch.tensor(output_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(bool_mask))

        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(1))

        in_channels = node_module.weight.size(0)
        broadcasted_mask = bool_mask.repeat(in_channels).view(
            in_channels, bool_mask.size(0))
        new_weight_shape = list(node_module.weight.shape)
        new_weight_shape[1] = new_num_channels

        node_module.out_channels = new_num_channels
        node_module.weight = torch.nn.Parameter(
            node_module.weight[broadcasted_mask].view(new_weight_shape))

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by pruning mask. Old output filter count: {}, new filter count:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          node_module.out_channels))
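
The broadcast-and-view indexing above is equivalent to selecting output-channel columns of the weight tensor with the boolean mask directly; a sketch with hypothetical shapes:

    import torch

    weight = torch.randn(4, 6, 3, 3)  # ConvTranspose weight: (in, out, kH, kW)
    bool_mask = torch.tensor([1, 0, 1, 1, 0, 1], dtype=torch.bool)
    pruned = weight[:, bool_mask]     # keeps 4 of the 6 output channels
    assert pruned.shape == (4, 4, 3, 3)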
Example #6
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(bool_mask))

        is_depthwise = is_prunable_depthwise_conv(node)
        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(1))

        if is_depthwise:
            # In the depthwise case, output channels are pruned via the input mask; here we only fix up the new number of input channels
            node_module.groups = new_num_channels
            node_module.in_channels = new_num_channels
            old_num_channels = int(node_module.weight.size(0))
        else:
            out_channels = node_module.weight.size(0)
            broadcasted_mask = bool_mask.repeat(out_channels).view(
                out_channels, bool_mask.size(0))
            new_weight_shape = list(node_module.weight.shape)
            new_weight_shape[1] = new_num_channels

            node_module.in_channels = new_num_channels
            node_module.weight = torch.nn.Parameter(
                node_module.weight[broadcasted_mask].view(new_weight_shape))

        nncf_logger.info(
            'Pruned Convolution {} by input mask. Old input filter count: {}, new filter count:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          new_num_channels))
Example #7
    def _apply_minmax_init(self,
                           min_values,
                           max_values,
                           log_module_name: str = None):
        if torch.any(torch.eq(min_values, np.inf)) or torch.any(
                torch.eq(max_values, -np.inf)):
            raise AttributeError(
                'Statistics are not collected for {}'.format(log_module_name))
        sign = torch.any(torch.lt(min_values, 0))
        if self._signedness_to_force is not None and sign != self._signedness_to_force:
            nncf_logger.warning("Forcing signed to {} for module {}".format(
                self._signedness_to_force, log_module_name))
            sign = self._signedness_to_force
        self.signed = int(sign)

        abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
        # Clamp very small scales from below to avoid degenerate quantization ranges.
        SCALE_LOWER_THRESHOLD = 0.1
        mask = torch.gt(abs_max, SCALE_LOWER_THRESHOLD)
        self._scale_param_storage.data = torch.where(
            mask, abs_max,
            SCALE_LOWER_THRESHOLD * torch.ones_like(self._scale_param_storage))
        if self._is_using_log_scale_storage:
            self._scale_param_storage.data.log_()

        nncf_logger.info("Set sign: {} and scale: {} for {}".format(
            self.signed, get_flat_tensor_contents_string(self.scale),
            log_module_name))
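
A numeric sketch of the symmetric scale initialization above, with hypothetical statistics:

    import torch

    min_values = torch.tensor([-0.5, 0.02])
    max_values = torch.tensor([0.8, 0.05])
    signed = bool(torch.any(min_values < 0))                  # True
    abs_max = torch.max(max_values.abs(), min_values.abs())   # [0.80, 0.05]
    # Scales below the 0.1 threshold are clamped to avoid degenerate ranges.
    scale = torch.where(abs_max > 0.1, abs_max,
                        torch.full_like(abs_max, 0.1))        # [0.80, 0.10]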
Example #8
    def get_average_traces(self, max_iter=500, tolerance=1e-5) -> Tensor:
        """
        Estimates average hessian trace for each parameter
        :param max_iter: maximum number of iterations for Hutchinson algorithm
        :param tolerance: - minimum relative tolerance for stopping the algorithm.
        It's calculated  between mean average trace from previous iteration and current one.
        :return: Tensor with average hessian trace per parameter
        """
        avg_total_trace = 0.
        avg_traces_per_iter = []  # type: List[Tensor]
        mean_avg_traces_per_param = None

        for i in range(max_iter):
            avg_traces_per_iter.append(self._calc_avg_traces_per_param())

            mean_avg_traces_per_param = self._get_mean(avg_traces_per_iter)
            mean_avg_total_trace = torch.sum(mean_avg_traces_per_param)

            diff_avg = abs(mean_avg_total_trace - avg_total_trace) / (
                avg_total_trace + self._diff_eps)
            if diff_avg < tolerance:
                return mean_avg_traces_per_param
            avg_total_trace = mean_avg_total_trace
            nncf_logger.info('{}# difference_avg={} avg_trace={}'.format(
                i, diff_avg, avg_total_trace))

        return mean_avg_traces_per_param
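
The Hutchinson estimator behind this method approximates tr(H) as E[v^T H v] over random Rademacher vectors. A self-contained sketch on an explicit matrix (not the NNCF implementation, which works on implicit Hessian-vector products):

    import torch

    def hutchinson_trace(H: torch.Tensor, n_samples: int = 1000) -> float:
        # tr(H) = E[v^T H v] for v with i.i.d. Rademacher (+/-1) entries.
        estimates = []
        for _ in range(n_samples):
            v = torch.randint(0, 2, (H.shape[0],), dtype=H.dtype) * 2 - 1
            estimates.append(v @ H @ v)
        return torch.stack(estimates).mean().item()

    H = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
    print(hutchinson_trace(H))  # approximately 6.0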
Example #9
    def _sparsify_weights(
            self, target_model: NNCFNetwork) -> List[PTInsertionCommand]:
        device = next(target_model.parameters()).device
        sparsified_module_nodes = target_model.get_weighted_original_graph_nodes(
            nncf_module_names=self.compressed_nncf_module_names)
        insertion_commands = []
        for module_node in sparsified_module_nodes:
            node_name = module_node.node_name

            if not self._should_consider_scope(node_name):
                nncf_logger.info(
                    "Ignored adding Weight Sparsifier in scope: {}".format(
                        node_name))
                continue

            nncf_logger.info(
                "Adding Weight Sparsifier in scope: {}".format(node_name))
            compression_lr_multiplier = \
                self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier',
                                                                        self.name)
            operation = self.create_weight_sparsifying_operation(
                module_node, compression_lr_multiplier)
            hook = operation.to(device)
            insertion_commands.append(
                PTInsertionCommand(
                    PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS,
                                  target_node_name=node_name), hook,
                    TransformationPriority.SPARSIFICATION_PRIORITY))
            sparsified_module = target_model.get_containing_module(node_name)
            self._sparsified_module_info.append(
                SparseModuleInfo(node_name, sparsified_module, hook))

        return insertion_commands
Example #10
    def set_pruning_level(self,
                          pruning_level: Union[float, Dict[int, float]],
                          run_batchnorm_adaptation: bool = False) -> None:
        """
        Set the global or groupwise pruning level in the model.
        If pruning_level is a float, the correspoding global pruning level is set in the model,
        either in terms of the percentage of filters pruned or as the percentage of flops
        removed, the latter being true in case the "prune_flops" flag of the controller is
        set to True.
        If pruning_level is a dict, the keys should correspond to layer group id's and the
        values to groupwise pruning level to be set in the model.
        """
        groupwise_pruning_levels_set = isinstance(pruning_level, dict)
        passed_pruning_level = pruning_level

        if not self.frozen:
            nncf_logger.info(
                "Computing filter importance scores and binary masks...")
            with torch.no_grad():
                if self.all_weights:
                    if groupwise_pruning_levels_set:
                        raise RuntimeError(
                            'Cannot set group-wise pruning levels with '
                            'all_weights=True')
                    # Non-uniform (global) importance-score-based pruning according
                    # to the global pruning level
                    if self.prune_flops:
                        self._set_binary_masks_for_pruned_modules_globally_by_flops_target(
                            pruning_level)
                    else:
                        self._set_binary_masks_for_pruned_modules_globally(
                            pruning_level)
                else:
                    if groupwise_pruning_levels_set:
                        group_ids = [
                            group.id for group in
                            self.pruned_module_groups_info.get_all_clusters()
                        ]
                        if set(pruning_level.keys()) != set(group_ids):
                            raise RuntimeError(
                                'Groupwise pruning level dict keys do not correspond to '
                                'layer group ids')
                    else:
                        # Pruning uniformly with the same pruning level across layers
                        if self.prune_flops:
                            # Looking for layerwise pruning level needed for the required flops pruning level
                            pruning_level = self._find_uniform_pruning_level_for_target_flops(
                                pruning_level)
                    self._set_binary_masks_for_pruned_modules_groupwise(
                        pruning_level)

        self._propagate_masks()
        if not groupwise_pruning_levels_set:
            self._pruning_level = passed_pruning_level
        else:
            self._pruning_level = self._calculate_global_weight_pruning_level()

        if run_batchnorm_adaptation:
            self._run_batchnorm_adaptation()
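
A usage sketch for both accepted forms, with a hypothetical controller and hypothetical group IDs:

    # Global level: prune 30% of filters (or of FLOPs, if prune_flops is set).
    controller.set_pruning_level(0.3)

    # Groupwise levels: the keys must match the layer group IDs exactly.
    controller.set_pruning_level({0: 0.2, 1: 0.5},
                                 run_batchnorm_adaptation=True)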
Example #11
 def on_train_begin(self, logs: dict = None):
     nncf_stats = self._statistics_fn()
     if self._log_tensorboard:
         self._dump_to_tensorboard(
             self._prepare_for_tensorboard(nncf_stats),
             self.model.optimizer.iterations.numpy())
     if self._log_text:
         nncf_logger.info(nncf_stats.to_str())
Example #12
 def load_best_checkpoint(self, model):
     resuming_checkpoint_path = self._best_checkpoint
     nncf_logger.info('Loading the best checkpoint found during training '
                      '{}...'.format(resuming_checkpoint_path))
     resuming_checkpoint = torch.load(resuming_checkpoint_path,
                                      map_location='cpu')
     resuming_model_state_dict = resuming_checkpoint.get(
         'state_dict', resuming_checkpoint)
     load_state(model, resuming_model_state_dict, is_resume=True)
Example #13
def register_default_init_args(nncf_config: 'NNCFConfig',
                               train_loader: torch.utils.data.DataLoader,
                               criterion: _Loss = None,
                               criterion_fn: Callable[[Any, Any, _Loss], torch.Tensor] = None,
                               train_steps_fn: Callable[[torch.utils.data.DataLoader, torch.nn.Module,
                                                         torch.optim.Optimizer, 'CompressionAlgorithmController',
                                                         Optional[int]], type(None)] = None,
                               validate_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader],
                                                     Tuple[float, float]] = None,
                               val_loader: torch.utils.data.DataLoader = None,
                               autoq_eval_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader], float] = None,
                               model_eval_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader], float] = None,
                               distributed_callbacks: Tuple[Callable, Callable] = None,
                               execution_parameters: 'ExecutionParameters' = None,
                               legr_train_optimizer: torch.optim.Optimizer = None,
                               device: str = None, ) -> 'NNCFConfig':
    nncf_config.register_extra_structs([QuantizationRangeInitArgs(data_loader=wrap_dataloader_for_init(train_loader),
                                                                  device=device),
                                        BNAdaptationInitArgs(data_loader=wrap_dataloader_for_init(train_loader),
                                                             device=device),

                                        ])
    if train_loader and train_steps_fn and val_loader and validate_fn:
        nncf_config.register_extra_structs([LeGRInitArgs(
            train_loader=train_loader,
            train_fn=train_steps_fn,
            val_loader=val_loader,
            val_fn=validate_fn,
            train_optimizer=legr_train_optimizer,
            nncf_config=nncf_config,
        )])

    if criterion is not None:
        if not criterion_fn:
            criterion_fn = default_criterion_fn
        nncf_config.register_extra_structs([QuantizationPrecisionInitArgs(criterion_fn=criterion_fn,
                                                                          criterion=criterion,
                                                                          data_loader=train_loader,
                                                                          device=device)])

    if autoq_eval_fn is not None:
        if not val_loader:
            val_loader = train_loader
        nncf_config.register_extra_structs([AutoQPrecisionInitArgs(data_loader=val_loader,
                                                                   eval_fn=autoq_eval_fn,
                                                                   nncf_config=nncf_config)])

    if model_eval_fn is not None:
        nncf_config.register_extra_structs([ModelEvaluationArgs(eval_fn=model_eval_fn)])

    if distributed_callbacks is None:
        if execution_parameters is None:
            nncf_logger.info('Please provide execution parameters for optimal model initialization')
        distributed_callbacks = (partial(default_distributed_wrapper, execution_parameters=execution_parameters),
                                 default_distributed_unwrapper)
    nncf_config.register_extra_structs([DistributedCallbacksArgs(*distributed_callbacks)])
    return nncf_config
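
A hypothetical usage sketch; the config and data loaders are assumed to already exist, and the compressed model would typically be created afterwards with NNCF's model creation entry point:

    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()
    nncf_config = register_default_init_args(nncf_config,
                                             train_loader,
                                             criterion=criterion,
                                             val_loader=val_loader,
                                             device='cuda')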
Example #14
 def prune_model(self):
     """
     Model pruner work in two stages:
     1. Mask propagation: propagate pruning masks through the graph.
     2. Applying calculated masks
     """
     nncf_logger.info('Start pruning model')
     self.mask_propagation()
     self.apply_mask()
     nncf_logger.info('Finished pruning model')
Example #15
    def evaluate_strategy(self,
                          collected_strategy: List,
                          skip_constraint=True) -> Tuple:
        assert len(collected_strategy) == len(self.master_df)
        if not skip_constraint:
            collected_strategy = self._constrain_model_size(collected_strategy)
        # This must be set after the constraint is applied
        self.master_df['action'] = collected_strategy

        if self.performant_bw:
            self._align_bw_action()
            configs_to_set = self.select_config_for_actions(
                self.master_df['action_aligned'])

            if self._dump_autoq_data or is_debug():
                self._dump_adjacent_quantizer_group_alignment()

            self.master_df['action'] = self.master_df['action_aligned']
        else:
            configs_to_set = self.select_config_for_actions(
                self.master_df['action'])

        self._apply_quantizer_configs_to_model(configs_to_set)

        for idx, qid in zip(self.master_df.index, self.master_df['qid']):
            logger.info("[Q.Env] {:50} | {}".format(
                str(self.qctrl.all_quantizations[find_qid_by_str(
                    self.qctrl, qid)]), idx))

        quantized_score = self._run_quantization_pipeline(
            finetune=self.finetune)

        current_model_size = self.model_size_calculator(
            self._get_quantizer_bitwidth())
        current_model_ratio = self.model_size_calculator.get_model_size_ratio(
            self._get_quantizer_bitwidth())

        current_model_bop_ratio = self.compression_ratio_calculator.run_for_quantizer_setup(
            self.qctrl.get_quantizer_setup_for_current_state())

        reward = self.reward(quantized_score, current_model_ratio)

        info_set = {
            'model_ratio': current_model_ratio,
            'accuracy': quantized_score,
            'model_size': current_model_size,
            'bop_ratio': current_model_bop_ratio
        }

        obs = self.get_normalized_obs(len(collected_strategy) - 1)
        done = True
        self._n_eval += 1

        return obs, reward, done, info_set
Example #16
 def _run_initial_training_phase(model, accuracy_aware_controller, runner):
     runner.configure_optimizers()
     for _ in range(runner.initial_training_phase_epochs):
         runner.train_epoch(model, accuracy_aware_controller)
     compressed_model_accuracy = runner.validate(model)
     runner.accuracy_budget = compressed_model_accuracy - runner.minimal_tolerable_accuracy
     runner.add_tensorboard_scalar('val/accuracy_aware/accuracy_budget',
                                   runner.accuracy_budget,
                                   runner.cumulative_epoch_count)
     nncf_logger.info('Accuracy budget value after training is {}'.format(
         runner.accuracy_budget))
Example #17
    def _run_quantization_pipeline(self, finetune=False) -> float:
        if self.qctrl.config:
            self._run_batchnorm_adaptation()

        if finetune:
            raise NotImplementedError(
                "Post-Quantization fine tuning is not implemented.")
        with torch.no_grad():
            quantized_score = self.eval_fn(self.qmodel, self.eval_loader)
            logger.info(
                "[Q.Env] Quantized Score: {:.3f}".format(quantized_score))
        return quantized_score
Example #18
    def _evaluate_pretrained_model(self):
        logger.info("[Q.Env] Evaluating Pretrained Model")
        self.qctrl.disable_weight_quantization()
        self.qctrl.disable_activation_quantization()

        with torch.no_grad():
            self.pretrained_score = self.eval_fn(self.qmodel, self.eval_loader)
            logger.info("Pretrained Score: {:.3f}".format(
                self.pretrained_score))

        self.qctrl.enable_weight_quantization()
        self.qctrl.enable_activation_quantization()
        self.qmodel.rebuild_graph()
Example #19
    def dump_statistics(self, model, compression_controller):
        statistics = compression_controller.statistics()

        if is_main_process():
            if self.verbose:
                nncf_logger.info(statistics.to_str())
            # Dump the best checkpoint for the current target compression rate
            if self.dump_checkpoints:
                self.dump_checkpoint(model, compression_controller)
            for key, value in prepare_for_tensorboard(statistics).items():
                if isinstance(value, (int, float)):
                    self.add_tensorboard_scalar(
                        'compression/statistics/{0}'.format(key), value,
                        self.cumulative_epoch_count)
Example #20
    def binarize(self, x):
        if self.training and not self.is_scale_initialized:
            # init scale using non-binarized activation statistics
            d = x.detach().contiguous().view(-1)
            top_num = max(1, round(d.shape[0] * 0.001))
            topk_res = d.topk(top_num)
            scale = topk_res[0].min()
            nncf_logger.info("Binarized activation scale set to: {}".format(
                scale.item()))
            self.scale.data[:] = scale.log()
            self.is_scale_initialized = True

        x = self.bin(x, self.scale.exp(), self.threshold.sigmoid())

        return x
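
The scale heuristic above effectively picks the ~99.9th percentile of the flattened activations: the smallest value among the top 0.1%. A sketch with hypothetical input:

    import torch

    d = torch.randn(10000)
    top_num = max(1, round(d.shape[0] * 0.001))  # 10 values
    scale = d.topk(top_num)[0].min()             # ~99.9th percentile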
Example #21
 def apply_mask(self):
     """
     Applying propagated masks for all nodes in topological order:
     1. running input_prune method for this node
     2. running output_prune method for this node
     """
     pruned_node_modules = []
     with torch.no_grad():
         for node in self._graph.topological_sort():
             node_cls = self.get_meta_operation_by_type_name(node.node_type)
             node_module = self._model.get_containing_module(node.node_name)
             if node_module not in pruned_node_modules:
                 node_cls.input_prune(self._model, node, self._graph)
                 node_cls.output_prune(self._model, node, self._graph)
                 pruned_node_modules.append(node_module)
         nncf_logger.info('Finished mask application step')
Example #22
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)

        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(0))

        node_module.in_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by input mask. Old input filter count: {}, new filter count:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          node_module.in_channels))
Example #23
    def _apply_minmax_init(self,
                           min_values,
                           max_values,
                           log_module_name: str = None):
        ranges = max_values - min_values
        max_range = torch.max(ranges)
        eps = 1e-2
        correction = (clamp(ranges, low=eps * max_range, high=max_range) -
                      ranges) * 0.5
        self._input_range_param_storage.data = (ranges + 2 * correction).data
        if self._is_using_log_scale_storage:
            self._input_range_param_storage.data.log_()

        self.input_low.data = (min_values - correction).data

        nncf_logger.info("Set input_low: {} and input_range: {} for {}".format(
            get_flat_tensor_contents_string(self.input_low),
            get_flat_tensor_contents_string(self.input_range),
            log_module_name))
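
A numeric sketch of the range correction above, with hypothetical statistics and torch.clamp standing in for the clamp helper; near-zero ranges are widened to at least eps * max_range, symmetrically around the original interval:

    import torch

    min_values = torch.tensor([-1.0, 0.0])
    max_values = torch.tensor([1.0, 0.001])
    ranges = max_values - min_values               # [2.000, 0.001]
    max_range = ranges.max()                       # 2.0
    eps = 1e-2
    correction = (torch.clamp(ranges, min=eps * max_range, max=max_range)
                  - ranges) * 0.5                  # [0.0000, 0.0095]
    input_low = min_values - correction            # [-1.0000, -0.0095]
    input_range = ranges + 2 * correction          # [2.000, 0.020]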
Example #24
 def load_best_checkpoint(self, model):
     # load checkpoint with highest compression rate and positive acc budget
     possible_checkpoint_rates = [
         comp_rate
         for (comp_rate, acc_budget) in self._compressed_training_history
         if acc_budget >= 0
     ]
     if not possible_checkpoint_rates:
         nncf_logger.warning(
             'Could not produce a compressed model satisfying the set accuracy '
             'degradation criterion during training. Consider increasing the number '
             'of training epochs.')
         return
     best_checkpoint_compression_rate = sorted(
         possible_checkpoint_rates)[-1]
     resuming_checkpoint_path = self._best_checkpoints[
         best_checkpoint_compression_rate]
     nncf_logger.info('Loading the best checkpoint found during training '
                      '{}...'.format(resuming_checkpoint_path))
     model.load_weights(resuming_checkpoint_path)
Example #25
 def load_best_checkpoint(self, model):
     # load checkpoint with highest compression rate and positive acc budget
     possible_checkpoint_rates = self.get_compression_rates_with_positive_acc_budget(
     )
     if not possible_checkpoint_rates:
         nncf_logger.warning(
             'Could not produce a compressed model satisfying the set accuracy '
             'degradation criterion during training. Consider increasing the number '
             'of training epochs.')
         return
     best_checkpoint_compression_rate = sorted(
         possible_checkpoint_rates)[-1]
     resuming_checkpoint_path = self._best_checkpoints[
         best_checkpoint_compression_rate]
     nncf_logger.info('Loading the best checkpoint found during training '
                      '{}...'.format(resuming_checkpoint_path))
     resuming_checkpoint = torch.load(resuming_checkpoint_path,
                                      map_location='cpu')
     resuming_model_state_dict = resuming_checkpoint.get(
         'state_dict', resuming_checkpoint)
     load_state(model, resuming_model_state_dict, is_resume=True)
Example #26
    def output_prune(cls, model: NNCFNetwork, node: NNCFNode,
                     graph: NNCFGraph) -> None:
        mask = node.data['output_mask']
        if mask is None:
            return

        bool_mask = torch.tensor(mask, dtype=torch.bool)

        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(0))

        node_module.out_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned Convolution {} by pruning mask. Old output filter count: {}, new filter count:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          node_module.out_channels))
Example #27
 def dump_checkpoint(self, model, compression_controller):
     if self._dump_checkpoint_fn is not None and is_main_process():
         self._dump_checkpoint_fn(model, compression_controller, self,
                                  self._log_dir)
     else:
         checkpoint = {
             'epoch': self.cumulative_epoch_count + 1,
             'state_dict': model.state_dict(),
             'compression_state':
             compression_controller.get_compression_state(),
             'best_metric_val': self.best_val_metric_value,
             'current_val_metric_value': self.current_val_metric_value,
             'optimizer': self.optimizer.state_dict(),
             'scheduler': compression_controller.scheduler.get_state()
         }
         checkpoint_path = osp.join(self._checkpoint_save_dir,
                                    'acc_aware_checkpoint_last.pth')
         torch.save(checkpoint, checkpoint_path)
         nncf_logger.info(
             "The checkpoint is saved in {}".format(checkpoint_path))
         self._save_best_checkpoint(checkpoint_path)
Example #28
    def run(self,
            model,
            train_epoch_fn,
            validate_fn,
            configure_optimizers_fn=None,
            dump_checkpoint_fn=None,
            tensorboard_writer=None,
            log_dir=None):
        self.runner.initialize_training_loop_fns(train_epoch_fn, validate_fn,
                                                 configure_optimizers_fn,
                                                 dump_checkpoint_fn,
                                                 tensorboard_writer, log_dir)
        self.runner.retrieve_uncompressed_model_accuracy(model)
        uncompressed_model_accuracy = self.runner.uncompressed_model_accuracy
        self.runner.calculate_minimal_tolerable_accuracy(
            uncompressed_model_accuracy)

        self.runner.configure_optimizers()
        for epoch in range(self.runner.maximal_total_epochs):
            compressed_model_accuracy = self.runner.validate(model)
            accuracy_budget = compressed_model_accuracy - self.runner.minimal_tolerable_accuracy
            accuracy_drop = uncompressed_model_accuracy - compressed_model_accuracy
            try:
                rel_accuracy_drop = 100 * (accuracy_drop /
                                           uncompressed_model_accuracy)
            except ZeroDivisionError:
                rel_accuracy_drop = 0
            if accuracy_budget >= 0:
                if epoch == 0:
                    nncf_logger.info(
                        'The accuracy criterion is reached. '
                        'Exiting the training loop after the initialization step '
                        'with compressed model accuracy value {:.4f}. Original model accuracy is {:.4f}. '
                        'The absolute accuracy drop is {:.4f}. '
                        'The relative accuracy drop is {:.2f}%.'.format(
                            compressed_model_accuracy,
                            uncompressed_model_accuracy, accuracy_drop,
                            rel_accuracy_drop))
                    self.runner.dump_statistics(model,
                                                self.compression_controller)
                    return model
                nncf_logger.info(
                    'The accuracy criterion is reached. '
                    'Exiting the training loop on epoch {} with '
                    'compressed model accuracy value {:.4f}. Original model accuracy is {:.4f}. '
                    'The absolute accuracy drop is {:.4f}. '
                    'The relative accuracy drop is {:.2f}%.'.format(
                        epoch, compressed_model_accuracy,
                        uncompressed_model_accuracy, accuracy_drop,
                        rel_accuracy_drop))
                return model
            nncf_logger.info('The absolute accuracy drop is {:.4f}. '
                             'The relative accuracy drop is {:.2f}%.'.format(
                                 accuracy_drop, rel_accuracy_drop))
            self.runner.train_epoch(model, self.compression_controller)
            self.runner.dump_statistics(model, self.compression_controller)

        self.runner.load_best_checkpoint(model)
        return model
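
A hypothetical invocation sketch; the training-loop object and the callbacks are assumed to follow the signatures above:

    trained_model = training_loop.run(model,
                                      train_epoch_fn=train_epoch,
                                      validate_fn=validate,
                                      configure_optimizers_fn=configure_optimizers,
                                      log_dir='runs/accuracy_aware')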
Example #29
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return

        node_module = model.get_containing_module(node.node_name)

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        old_num_channels = int(node_module.weight.size(0))
        new_num_channels = int(torch.sum(bool_mask))

        node_module.num_channels = new_num_channels
        node_module.num_groups = new_num_channels

        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned GroupNorm {} by input mask. Old num features: {}, new num features:'
            ' {}.'.format(node.data['key'], old_num_channels,
                          new_num_channels))
Example #30
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        node_module = model.get_containing_module(node.node_name)

        if isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
            assert node_module.target_weight_dim_for_compression == 0, \
                "Implemented only for target_weight_dim_for_compression == 0"
            old_num_channels = int(node_module.weight.size(0))
            new_num_channels = int(torch.sum(bool_mask))
            node_module.weight = torch.nn.Parameter(
                node_module.weight[bool_mask])
            node_module.n_channels = new_num_channels

            nncf_logger.info(
                'Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
                ' {}.'.format(node.data['key'], old_num_channels,
                              new_num_channels))