def load_state(model: torch.nn.Module, state_dict_to_load: dict, is_resume: bool = False,
               keys_to_ignore: List[str] = None) -> int:
    """
    Load parameters from a (possibly compressed) checkpoint into `model`.

    Designed for loading checkpoints of compressed models into an NNCFNetwork, but usable with any
    PyTorch module. Parameter names from `state_dict_to_load` are matched to the model's own state
    dict while discarding irrelevant prefixes added by NNCFNetwork or DataParallel /
    DistributedDataParallel wrapping; the matched parameters are then loaded into the model.

    :param model: The module that receives the parameters.
    :param state_dict_to_load: State dict holding the parameters to load. It is first passed to
        `maybe_convert_legacy_names_in_model_state`, whose return value is not used, so any name
        conversion presumably happens in place on this dict.
    :param is_resume: Controls the behaviour when matching is not clean. If True, an exception is
        raised when the model requires parameters missing from `state_dict_to_load`, when
        `state_dict_to_load` has parameters that match nothing in the model, or when parameter
        shapes differ. If False, no exception is raised. Typically False when loading an
        uncompressed model's weights into a model that already has compression applied, and True
        when loading a compressed model's weights into a compressed model for evaluation.
    :param keys_to_ignore: Parameter names to skip during matching.
    :return: The number of `state_dict_to_load` entries that were matched and loaded into the model.
    """
    target_state = model.state_dict()
    maybe_convert_legacy_names_in_model_state(state_dict_to_load)

    matcher = KeyMatcher(is_resume, state_dict_to_load, target_state, keys_to_ignore)
    matched_state = matcher.run()
    loaded_count = len(matched_state)
    # Deals with missing / unexpected / mismatched keys (per the docstring, raising only when is_resume is True).
    matcher.handle_problematic_keys()

    nncf_logger.info("Loaded {}/{} parameters".format(loaded_count, len(target_state)))
    # strict=False: only the matched subset is loaded; unmatched model entries keep their current values.
    model.load_state_dict(matched_state, strict=False)
    return loaded_count
def _filter_groups(self, pruned_nodes_clusterization: Clusterization,
                   can_prune: Dict[int, PruningAnalysisDecision]) -> None:
    """
    Remove from the clusterization every group of nodes that cannot be pruned as a whole.

    A group is kept only if every node in it may be pruned according to `can_prune`, which reflects
    user-defined constraints and connections inside the network. Otherwise the whole group is
    removed from the clusterization and the per-node reasons are logged.

    :param pruned_nodes_clusterization: Clusterization with groups of nodes to prune; modified in place.
    :param can_prune: Pruning decision for each node, keyed by node id.
    """
    # Iterate over a snapshot, because clusters are deleted from the clusterization inside the loop.
    for cluster in list(pruned_nodes_clusterization.get_all_clusters()):
        nodes_decisions = [can_prune[node.node_id] for node in cluster.elements]
        nodes_names = [node.node_name for node in cluster.elements]

        if all(nodes_decisions):
            nncf_logger.info('Group of nodes [{}] will be pruned together.'.format(', '.join(nodes_names)))
            continue

        cannot_prune_messages = []
        for name, decision in zip(nodes_names, nodes_decisions):
            if decision:
                continue
            message = PruningAnalysisReason.message(name, decision)
            if message:
                cannot_prune_messages.append(message)

        nncf_logger.info(
            "Group of nodes [{}] can't be pruned, because some nodes shouldn't be pruned, "
            'error messages for these nodes: {}.'.format(
                ', '.join(nodes_names), ', '.join(cannot_prune_messages)))
        pruned_nodes_clusterization.delete_cluster(cluster.id)
def dump_metric_MB(self, metric_per_qconfig_sequence: List[Tensor]):
    """
    Plot the metric of every qconfig sequence against model size and save it as 'Pareto_Frontier'.

    The point with the median metric value is highlighted. The metric distribution (min, max,
    median, index of the median, total count) is logged.

    :param metric_per_qconfig_sequence: One scalar metric tensor per qconfig sequence; element i is
        plotted against `self._model_sizes[i]`.
    """
    import matplotlib.pyplot as plt

    metric_values = [metric.item() for metric in metric_per_qconfig_sequence]
    cm = torch.Tensor(metric_per_qconfig_sequence)
    cm_median = cm.median()
    cm_m = cm_median.item()
    # Locate the median inside `cm` itself (same dtype). Comparing the float32-rounded Python value
    # against the original tensors via list.index() can fail to find a match for float64 inputs.
    qconfig_index = int(torch.nonzero(cm == cm_median)[0])

    fig = plt.figure()
    try:
        fig.suptitle('Pareto Frontier')
        ax = fig.add_subplot(2, 1, 1)
        ax.set_yscale('log')
        ax.set_xlabel('Model Size (MB)')
        ax.set_ylabel('Metric value (total perturbation)')
        ax.scatter(self._model_sizes, metric_values, s=20, facecolors='none', edgecolors='r')

        ms_m = self._model_sizes[qconfig_index]
        ax.scatter(ms_m, cm_m, s=30, facecolors='none', edgecolors='b', label='median from all metrics')
        ax.legend()
        fig.savefig(os.path.join(self._dump_dir, 'Pareto_Frontier'))
    finally:
        # Release the figure; otherwise every call leaves one more figure open in pyplot.
        plt.close(fig)

    nncf_logger.info(
        'Distribution of HAWQ metrics: min_value={:.3f}, max_value={:.3f}, median_value={:.3f}, '
        'median_index={}, total_number={}'.format(
            cm.min().item(), cm.max().item(), cm_m, qconfig_index, len(metric_per_qconfig_sequence)))
def set_pruning_level(self, pruning_level: float, run_batchnorm_adaptation: bool = False):
    """
    Recompute the binary pruning masks for the requested pruning level.

    The requested level is always stored in `self.pruning_rate`. Masks are recomputed only when
    the controller is not frozen: globally across all weights when `all_weights` is set, otherwise
    group-wise. When `prune_flops` is set, the level is treated as a FLOPs target; in the group-wise
    case it is converted into a uniform per-layer level first.

    :param pruning_level: Target pruning level.
    :param run_batchnorm_adaptation: Whether to run BatchNorm adaptation at the end.
    """
    # The scheduler's value may be the fraction of params to prune; store it before any conversion.
    self.pruning_rate = pruning_level

    if not self.frozen:
        nncf_logger.info('Computing filter importance scores and binary masks...')
        if self.all_weights and self.prune_flops:
            self._set_binary_masks_for_pruned_modules_globally_by_flops_target(pruning_level)
        elif self.all_weights:
            self._set_binary_masks_for_pruned_layers_globally(pruning_level)
        else:
            layerwise_level = pruning_level
            if self.prune_flops:
                # Find the uniform per-layer level that achieves the requested FLOPs pruning level.
                layerwise_level = self._find_uniform_pruning_level_for_target_flops(pruning_level)
            self._set_binary_masks_for_pruned_layers_groupwise(layerwise_level)

    if run_batchnorm_adaptation:
        self._run_batchnorm_adaptation()
def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None:
    """
    Shrink a ConvTranspose module's output channels according to the node's output mask.

    ConvTranspose weights keep output channels on dim 1, so the weight is masked along that dim.
    The bias (if any) is masked as well, and `out_channels` is updated. Does nothing when the node
    has no output mask.
    """
    output_mask = node.data['output_mask']
    if output_mask is None:
        return

    keep = torch.tensor(output_mask, dtype=torch.bool)
    kept_channels = int(torch.sum(keep))
    deconv = model.get_containing_module(node.node_name)
    previous_out_channels = int(deconv.weight.size(1))

    # Broadcast the 1-D channel mask over dim 0 so that it selects along dim 1 for every row.
    rows = deconv.weight.size(0)
    row_mask = keep.repeat(rows).view(rows, keep.size(0))
    pruned_shape = list(deconv.weight.shape)
    pruned_shape[1] = kept_channels

    deconv.out_channels = kept_channels
    deconv.weight = torch.nn.Parameter(deconv.weight[row_mask].view(pruned_shape))
    if deconv.bias is not None:
        deconv.bias = torch.nn.Parameter(deconv.bias[keep])

    nncf_logger.info(
        'Pruned ConvTranspose {} by pruning mask. Old output filters number: {}, new filters number:'
        ' {}.'.format(node.data['key'], previous_out_channels, deconv.out_channels))
def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None:
    """
    Shrink a convolution's input channels according to the node's first input mask.

    For a prunable depthwise convolution only `groups` and `in_channels` are updated here (its
    channels are pruned by the output mask). Otherwise the weight is masked along dim 1 (input
    channels of a Conv weight) and `in_channels` is updated. Does nothing when there is no input mask.
    """
    input_mask = node.data['input_masks'][0]
    if input_mask is None:
        return

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    # Count kept channels from the converted bool tensor: `input_mask` may be a list/array,
    # which torch.sum does not accept.
    new_num_channels = int(torch.sum(bool_mask))
    is_depthwise = is_prunable_depthwise_conv(node)
    node_module = model.get_containing_module(node.node_name)

    if is_depthwise:
        # In the depthwise case output channels are pruned by the input mask elsewhere;
        # here only the number of input channels / groups is fixed.
        old_num_channels = int(node_module.weight.size(0))
        node_module.groups = new_num_channels
        node_module.in_channels = new_num_channels
    else:
        old_num_channels = int(node_module.weight.size(1))
        out_channels = node_module.weight.size(0)
        # Broadcast the 1-D channel mask over dim 0 so that it selects along dim 1 for every row.
        broadcasted_mask = bool_mask.repeat(out_channels).view(out_channels, bool_mask.size(0))
        new_weight_shape = list(node_module.weight.shape)
        new_weight_shape[1] = new_num_channels
        node_module.in_channels = new_num_channels
        node_module.weight = torch.nn.Parameter(node_module.weight[broadcasted_mask].view(new_weight_shape))

    nncf_logger.info(
        'Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
        ' {}.'.format(node.data['key'], old_num_channels, new_num_channels))
def _apply_minmax_init(self, min_values, max_values, log_module_name: str = None): if torch.any(torch.eq(min_values, np.inf)) or torch.any( torch.eq(max_values, -np.inf)): raise AttributeError( 'Statistics is not collected for {}'.format(log_module_name)) sign = torch.any(torch.lt(min_values, 0)) if self._signedness_to_force is not None and sign != self._signedness_to_force: nncf_logger.warning("Forcing signed to {} for module {}".format( self._signedness_to_force, log_module_name)) sign = self._signedness_to_force self.signed = int(sign) abs_max = torch.max(torch.abs(max_values), torch.abs(min_values)) SCALE_LOWER_THRESHOLD = 0.1 mask = torch.gt(abs_max, SCALE_LOWER_THRESHOLD) self._scale_param_storage.data = torch.where( mask, abs_max, SCALE_LOWER_THRESHOLD * torch.ones_like(self._scale_param_storage)) if self._is_using_log_scale_storage: self._scale_param_storage.data.log_() nncf_logger.info("Set sign: {} and scale: {} for {}".format( self.signed, get_flat_tensor_contents_string(self.scale), log_module_name))
def get_average_traces(self, max_iter=500, tolerance=1e-5) -> Tensor:
    """
    Estimate the average Hessian trace of every parameter with the Hutchinson algorithm.

    Per-iteration estimates are accumulated and averaged. Iteration stops early once the relative
    change of the summed averaged trace, between the previous iteration and the current one, drops
    below `tolerance`.

    :param max_iter: Maximum number of Hutchinson iterations.
    :param tolerance: Minimum relative change of the mean average trace required to keep iterating.
    :return: Tensor with the average Hessian trace per parameter (None when max_iter <= 0).
    """
    previous_total = 0.
    per_iteration_estimates = []  # type: List[Tensor]
    running_mean = None
    for iteration in range(max_iter):
        per_iteration_estimates.append(self._calc_avg_traces_per_param())
        running_mean = self._get_mean(per_iteration_estimates)
        current_total = torch.sum(running_mean)
        # `_diff_eps` guards against division by zero (previous_total starts at 0).
        relative_change = abs(current_total - previous_total) / (previous_total + self._diff_eps)
        if relative_change < tolerance:
            break
        previous_total = current_total
        nncf_logger.info('{}# difference_avg={} avg_trace={}'.format(iteration, relative_change, previous_total))
    return running_mean
def _sparsify_weights(self, target_model: NNCFNetwork) -> List[PTInsertionCommand]:
    """
    Build insertion commands that attach a weight-sparsifying operation to each weighted module.

    Only nodes accepted by `_should_consider_scope` are processed. Each created operation is moved
    to the device of the model's first parameter, wrapped in an OPERATION_WITH_WEIGHTS insertion
    command with SPARSIFICATION_PRIORITY, and recorded in `self._sparsified_module_info`.

    :param target_model: The model whose weighted nodes are sparsified.
    :return: The insertion commands to apply to `target_model`.
    """
    device = next(target_model.parameters()).device
    weighted_nodes = target_model.get_weighted_original_graph_nodes(
        nncf_module_names=self.compressed_nncf_module_names)

    commands = []
    for weighted_node in weighted_nodes:
        scope = weighted_node.node_name
        if not self._should_consider_scope(scope):
            nncf_logger.info("Ignored adding Weight Sparsifier in scope: {}".format(scope))
            continue

        nncf_logger.info("Adding Weight Sparsifier in scope: {}".format(scope))
        lr_multiplier = self.config.get_redefinable_global_param_value_for_algo(
            'compression_lr_multiplier', self.name)
        sparsifier = self.create_weight_sparsifying_operation(weighted_node, lr_multiplier).to(device)
        target_point = PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS, target_node_name=scope)
        commands.append(PTInsertionCommand(target_point, sparsifier,
                                           TransformationPriority.SPARSIFICATION_PRIORITY))

        owning_module = target_model.get_containing_module(scope)
        self._sparsified_module_info.append(SparseModuleInfo(scope, owning_module, sparsifier))
    return commands
def set_pruning_level(self, pruning_level: Union[float, Dict[int, float]],
                      run_batchnorm_adaptation: bool = False) -> None:
    """
    Set the global or group-wise pruning level of the model.

    If `pruning_level` is a float, the corresponding global pruning level is applied, either as
    the fraction of filters pruned or, when the controller's `prune_flops` flag is set, as the
    fraction of FLOPs removed. If it is a dict, its keys must be exactly the layer group ids and its
    values the group-wise pruning levels; this is not allowed together with `all_weights=True`.

    Masks are recomputed only when the controller is not frozen. Masks are then propagated and
    `_pruning_level` is updated in every case.

    :param pruning_level: Global level (float) or mapping of group id to pruning level.
    :param run_batchnorm_adaptation: Whether to run BatchNorm adaptation at the end.
    :raises RuntimeError: If group-wise levels are passed with `all_weights=True`, or if the dict
        keys do not match the layer group ids.
    """
    per_group = isinstance(pruning_level, dict)
    requested_level = pruning_level

    if not self.frozen:
        nncf_logger.info("Computing filter importance scores and binary masks...")
        with torch.no_grad():
            if self.all_weights:
                if per_group:
                    raise RuntimeError('Cannot set group-wise pruning levels with '
                                       'all_weights=True')
                # Non-uniform (global) importance-score-based pruning according to the global level.
                if self.prune_flops:
                    self._set_binary_masks_for_pruned_modules_globally_by_flops_target(pruning_level)
                else:
                    self._set_binary_masks_for_pruned_modules_globally(pruning_level)
            else:
                if per_group:
                    known_group_ids = {group.id for group in self.pruned_module_groups_info.get_all_clusters()}
                    if set(pruning_level) != known_group_ids:
                        raise RuntimeError('Groupwise pruning level dict keys do not correspond to '
                                           'layer group ids')
                elif self.prune_flops:
                    # Uniform pruning: find the per-layer level that yields the requested FLOPs level.
                    pruning_level = self._find_uniform_pruning_level_for_target_flops(pruning_level)
                self._set_binary_masks_for_pruned_modules_groupwise(pruning_level)

    self._propagate_masks()
    if per_group:
        self._pruning_level = self._calculate_global_weight_pruning_level()
    else:
        self._pruning_level = requested_level

    if run_batchnorm_adaptation:
        self._run_batchnorm_adaptation()
def on_train_begin(self, logs: dict = None):
    """
    Report NNCF statistics at the start of training.

    The statistics returned by `self._statistics_fn()` are written to TensorBoard (at the optimizer's
    current iteration) when `_log_tensorboard` is set, and to the text log when `_log_text` is set.

    :param logs: Not used by this method.
    """
    statistics = self._statistics_fn()

    if self._log_tensorboard:
        tensorboard_data = self._prepare_for_tensorboard(statistics)
        self._dump_to_tensorboard(tensorboard_data, self.model.optimizer.iterations.numpy())

    if self._log_text:
        nncf_logger.info(statistics.to_str())
def load_best_checkpoint(self, model):
    """
    Load the best checkpoint recorded during training (`self._best_checkpoint`) into `model`.

    The file is loaded onto the CPU. Its 'state_dict' entry is used when present; otherwise the
    loaded object itself is treated as the state dict. Loading goes through
    `load_state(..., is_resume=True)`.
    """
    checkpoint_path = self._best_checkpoint
    nncf_logger.info('Loading the best checkpoint found during training '
                     '{}...'.format(checkpoint_path))
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted locations.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint.get('state_dict', checkpoint)
    load_state(model, state_dict, is_resume=True)
def register_default_init_args(nncf_config: 'NNCFConfig',
                               train_loader: torch.utils.data.DataLoader,
                               criterion: _Loss = None,
                               criterion_fn: Callable[[Any, Any, _Loss], torch.Tensor] = None,
                               train_steps_fn: Callable[[torch.utils.data.DataLoader, torch.nn.Module,
                                                         torch.optim.Optimizer, 'CompressionAlgorithmController',
                                                         Optional[int]], type(None)] = None,
                               validate_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader],
                                                     Tuple[float, float]] = None,
                               val_loader: torch.utils.data.DataLoader = None,
                               autoq_eval_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader], float] = None,
                               model_eval_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader], float] = None,
                               distributed_callbacks: Tuple[Callable, Callable] = None,
                               execution_parameters: 'ExecutionParameters' = None,
                               legr_train_optimizer: torch.optim.Optimizer = None,
                               device: str = None,
                               ) -> 'NNCFConfig':
    """
    Register the default initialization argument structures in `nncf_config`.

    Always registered:
    - QuantizationRangeInitArgs and BNAdaptationInitArgs over `train_loader`;
    - DistributedCallbacksArgs, using default wrap/unwrap callbacks when `distributed_callbacks`
      is None.

    Registered only when their inputs are provided:
    - LeGRInitArgs, when `train_loader`, `train_steps_fn`, `val_loader` and `validate_fn` are all given;
    - QuantizationPrecisionInitArgs, when `criterion` is given (`criterion_fn` defaults to
      `default_criterion_fn`);
    - AutoQPrecisionInitArgs, when `autoq_eval_fn` is given (evaluating on `train_loader` if no
      `val_loader` is given);
    - ModelEvaluationArgs, when `model_eval_fn` is given.

    :return: The same `nncf_config`, with the extra structures registered.
    """
    # Loaders are checked with `is not None` rather than truthiness: `bool(loader)` calls
    # `len(loader)`, which raises TypeError for loaders over iterable datasets without `__len__`.
    nncf_config.register_extra_structs([
        QuantizationRangeInitArgs(data_loader=wrap_dataloader_for_init(train_loader), device=device),
        BNAdaptationInitArgs(data_loader=wrap_dataloader_for_init(train_loader), device=device),
    ])

    legr_inputs = (train_loader, train_steps_fn, val_loader, validate_fn)
    if all(arg is not None for arg in legr_inputs):
        nncf_config.register_extra_structs([LeGRInitArgs(
            train_loader=train_loader,
            train_fn=train_steps_fn,
            val_loader=val_loader,
            val_fn=validate_fn,
            train_optimizer=legr_train_optimizer,
            nncf_config=nncf_config,
        )])

    if criterion is not None:
        if not criterion_fn:
            criterion_fn = default_criterion_fn
        nncf_config.register_extra_structs([QuantizationPrecisionInitArgs(criterion_fn=criterion_fn,
                                                                           criterion=criterion,
                                                                           data_loader=train_loader,
                                                                           device=device)])

    if autoq_eval_fn is not None:
        if val_loader is None:
            val_loader = train_loader
        nncf_config.register_extra_structs([AutoQPrecisionInitArgs(data_loader=val_loader,
                                                                   eval_fn=autoq_eval_fn,
                                                                   nncf_config=nncf_config)])

    if model_eval_fn is not None:
        nncf_config.register_extra_structs([ModelEvaluationArgs(eval_fn=model_eval_fn)])

    if distributed_callbacks is None:
        if execution_parameters is None:
            nncf_logger.info('Please, provide execution parameters for optimal model initialization')
        distributed_callbacks = (partial(default_distributed_wrapper, execution_parameters=execution_parameters),
                                 default_distributed_unwrapper)
    nncf_config.register_extra_structs([DistributedCallbacksArgs(*distributed_callbacks)])
    return nncf_config
def prune_model(self):
    """
    Prune the model in two consecutive stages.

    1. Mask propagation: propagate the pruning masks through the graph (`mask_propagation`).
    2. Mask application: apply the calculated masks to the modules (`apply_mask`).
    """
    nncf_logger.info('Start pruning model')
    # Stage 1: masks must be propagated before they can be applied.
    self.mask_propagation()
    # Stage 2: physically prune the modules according to the propagated masks.
    self.apply_mask()
    nncf_logger.info('Finished pruning model')
def evaluate_strategy(self, collected_strategy: List, skip_constraint=True) -> Tuple:
    """
    Apply a per-quantizer bitwidth strategy to the model, evaluate it and return a step result.

    :param collected_strategy: One action per row of `self.master_df`.
    :param skip_constraint: Unless this is exactly True, the strategy is first passed through
        `_constrain_model_size`.
    :return: Tuple (obs, reward, done, info_set). `done` is always True; `info_set` holds
        'model_ratio', 'accuracy', 'model_size' and 'bop_ratio'.
    """
    assert len(collected_strategy) == len(self.master_df)
    if skip_constraint is not True:
        collected_strategy = self._constrain_model_size(collected_strategy)
    # This must be assigned after the (optional) constraint.
    self.master_df['action'] = collected_strategy

    if not self.performant_bw:
        configs_to_set = self.select_config_for_actions(self.master_df['action'])
    else:
        self._align_bw_action()
        configs_to_set = self.select_config_for_actions(self.master_df['action_aligned'])
        if self._dump_autoq_data or is_debug():
            self._dump_adjacent_quantizer_group_alignment()
        self.master_df['action'] = self.master_df['action_aligned']

    self._apply_quantizer_configs_to_model(configs_to_set)

    for row_index, qid in zip(self.master_df.index, self.master_df['qid']):
        quantizer = self.qctrl.all_quantizations[find_qid_by_str(self.qctrl, qid)]
        logger.info("[Q.Env] {:50} | {}".format(str(quantizer), row_index))

    quantized_score = self._run_quantization_pipeline(finetune=self.finetune)

    model_size = self.model_size_calculator(self._get_quantizer_bitwidth())
    model_ratio = self.model_size_calculator.get_model_size_ratio(self._get_quantizer_bitwidth())
    bop_ratio = self.compression_ratio_calculator.run_for_quantizer_setup(
        self.qctrl.get_quantizer_setup_for_current_state())

    reward = self.reward(quantized_score, model_ratio)

    info_set = {
        'model_ratio': model_ratio,
        'accuracy': quantized_score,
        'model_size': model_size,
        'bop_ratio': bop_ratio,
    }

    obs = self.get_normalized_obs(len(collected_strategy) - 1)
    self._n_eval += 1
    return obs, reward, True, info_set
def _run_initial_training_phase(model, accuracy_aware_controller, runner):
    """
    Train for the runner's initial number of epochs, then compute and report the accuracy budget.

    The budget (validation accuracy minus the minimal tolerable accuracy) is stored in
    `runner.accuracy_bugdet`, written to TensorBoard, and logged.
    """
    runner.configure_optimizers()
    for _ in range(runner.initial_training_phase_epochs):
        runner.train_epoch(model, accuracy_aware_controller)

    accuracy_after_training = runner.validate(model)
    # NOTE: 'bugdet' (sic) is kept as-is because other code may read this attribute / tag name.
    runner.accuracy_bugdet = accuracy_after_training - runner.minimal_tolerable_accuracy
    runner.add_tensorboard_scalar('val/accuracy_aware/accuracy_bugdet', runner.accuracy_bugdet,
                                  runner.cumulative_epoch_count)
    nncf_logger.info('Accuracy budget value after training is {}'.format(runner.accuracy_bugdet))
def _run_quantization_pipeline(self, finetune=False) -> float: if self.qctrl.config: self._run_batchnorm_adaptation() if finetune: raise NotImplementedError( "Post-Quantization fine tuning is not implemented.") with torch.no_grad(): quantized_score = self.eval_fn(self.qmodel, self.eval_loader) logger.info( "[Q.Env] Quantized Score: {:.3f}".format(quantized_score)) return quantized_score
def _evaluate_pretrained_model(self):
    """
    Evaluate the model with weight and activation quantization disabled.

    The result is stored in `self.pretrained_score`. Quantization is re-enabled afterwards even if
    evaluation raises, so a failed evaluation cannot leave the model silently unquantized. The
    model graph is rebuilt after a successful evaluation.
    """
    logger.info("[Q.Env] Evaluating Pretrained Model")
    self.qctrl.disable_weight_quantization()
    self.qctrl.disable_activation_quantization()
    try:
        with torch.no_grad():
            self.pretrained_score = self.eval_fn(self.qmodel, self.eval_loader)
            logger.info("Pretrained Score: {:.3f}".format(self.pretrained_score))
    finally:
        # Always restore quantization, including when eval_fn raises.
        self.qctrl.enable_weight_quantization()
        self.qctrl.enable_activation_quantization()
    self.qmodel.rebuild_graph()
def dump_statistics(self, model, compression_controller):
    """
    Report compression statistics and, if enabled, dump a checkpoint.

    Statistics are always collected from `compression_controller`. Only the main process logs them
    (when `verbose`), dumps a checkpoint (when `dump_checkpoints`) and writes numeric (int/float)
    statistics to TensorBoard under 'compression/statistics/<key>'.
    """
    statistics = compression_controller.statistics()
    if not is_main_process():
        return

    if self.verbose:
        nncf_logger.info(statistics.to_str())

    # Dump the best checkpoint for the current target compression rate.
    if self.dump_checkpoints:
        self.dump_checkpoint(model, compression_controller)

    for stat_name, stat_value in prepare_for_tensorboard(statistics).items():
        if not isinstance(stat_value, (int, float)):
            continue
        self.add_tensorboard_scalar('compression/statistics/{0}'.format(stat_name),
                                    stat_value, self.cumulative_epoch_count)
def binarize(self, x):
    """
    Binarize the activation tensor `x`.

    On the first call in training mode, the scale is initialised from non-binarized activation
    statistics. The input is detached and flattened, and the scale is set to the smallest of its
    top 0.1% (at least one) values; its logarithm is stored in `self.scale`. Binarization itself is
    delegated to `self.bin` with exp(scale) and sigmoid(threshold).
    """
    needs_scale_init = self.training and not self.is_scale_initialized
    if needs_scale_init:
        flat = x.detach().data.contiguous().view(-1)
        top_count = max(1, round(flat.shape[0] * 0.001))
        initial_scale = flat.topk(top_count)[0].min()
        nncf_logger.info("Binarized activation scale set to: {}".format(initial_scale.item()))
        self.scale.data[:] = initial_scale.log()
        self.is_scale_initialized = True
    return self.bin(x, self.scale.exp(), self.threshold.sigmoid())
def apply_mask(self):
    """
    Apply the propagated masks to all nodes in topological order.

    For every node, the meta-operation registered for its type first runs `input_prune` and then
    `output_prune`. A module containing several nodes is pruned only once, for the first such node
    encountered. All pruning happens under `torch.no_grad()`.
    """
    handled_modules = []
    with torch.no_grad():
        for node in self._graph.topological_sort():
            meta_op = self.get_meta_operation_by_type_name(node.node_type)
            module = self._model.get_containing_module(node.node_name)
            if module in handled_modules:
                continue
            meta_op.input_prune(self._model, node, self._graph)
            meta_op.output_prune(self._model, node, self._graph)
            handled_modules.append(module)
    nncf_logger.info('Finished mask applying step')
def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None:
    """
    Shrink a ConvTranspose module's input channels according to the node's first input mask.

    ConvTranspose weights keep input channels on dim 0, so the weight is masked along that dim and
    `in_channels` is updated. Does nothing when the node has no input mask.
    """
    mask = node.data['input_masks'][0]
    if mask is None:
        return

    keep = torch.tensor(mask, dtype=torch.bool)
    deconv = model.get_containing_module(node.node_name)
    previous_in_channels = int(deconv.weight.size(0))

    deconv.in_channels = int(torch.sum(keep))
    deconv.weight = torch.nn.Parameter(deconv.weight[keep])

    nncf_logger.info(
        'Pruned ConvTranspose {} by input mask. Old input filters number: {}, new filters number:'
        ' {}.'.format(node.data['key'], previous_in_channels, deconv.in_channels))
def _apply_minmax_init(self, min_values, max_values, log_module_name: str = None):
    """
    Initialise the asymmetric quantizer's `input_low` and input range from min/max statistics.

    Each element of the range (max - min) is clamped to [1e-2 * largest_range, largest_range]. Any
    widening is split evenly between both ends, so `input_low` moves down by half of it. With
    log-scale storage, the logarithm of the range is stored.

    :param min_values: Collected minimum values (tensor).
    :param max_values: Collected maximum values (tensor).
    :param log_module_name: Name used in the log message.
    """
    span = max_values - min_values
    widest_span = torch.max(span)
    eps = 1e-2
    half_widening = (clamp(span, low=eps * widest_span, high=widest_span) - span) * 0.5

    self._input_range_param_storage.data = (span + 2 * half_widening).data
    if self._is_using_log_scale_storage:
        self._input_range_param_storage.data.log_()
    self.input_low.data = (min_values - half_widening).data

    nncf_logger.info("Set input_low: {} and input_range: {} for {}".format(
        get_flat_tensor_contents_string(self.input_low),
        get_flat_tensor_contents_string(self.input_range),
        log_module_name))
def load_best_checkpoint(self, model):
    """
    Load the checkpoint with the highest compression rate whose accuracy budget is non-negative.

    Candidates are the compression rates in `self._compressed_training_history`
    ((compression_rate, accuracy_budget) pairs) with budget >= 0. The checkpoint path is taken from
    `self._best_checkpoints[rate]` and loaded with `model.load_weights`. If no candidate exists, a
    warning is logged and the model's current weights are kept.
    """
    possible_checkpoint_rates = [
        comp_rate for (comp_rate, acc_budget) in self._compressed_training_history if acc_budget >= 0
    ]
    if not possible_checkpoint_rates:
        nncf_logger.warning(
            'Could not produce a compressed model satisfying the set accuracy '
            'degradation criterion during training. Increasing the number of training '
            'epochs')
        # Nothing qualifies: selecting from the empty list below would raise IndexError.
        return

    best_checkpoint_compression_rate = max(possible_checkpoint_rates)
    resuming_checkpoint_path = self._best_checkpoints[best_checkpoint_compression_rate]
    nncf_logger.info('Loading the best checkpoint found during training '
                     '{}...'.format(resuming_checkpoint_path))
    model.load_weights(resuming_checkpoint_path)
def load_best_checkpoint(self, model):
    """
    Load the checkpoint with the highest compression rate that has a positive accuracy budget.

    Candidate rates come from `get_compression_rates_with_positive_acc_budget()`. The checkpoint path
    is taken from `self._best_checkpoints[rate]`, loaded onto the CPU, and its 'state_dict' entry
    (or the whole object) is passed to `load_state(..., is_resume=True)`. If no candidate exists, a
    warning is logged and the model's current weights are kept.
    """
    possible_checkpoint_rates = self.get_compression_rates_with_positive_acc_budget()
    if not possible_checkpoint_rates:
        nncf_logger.warning(
            'Could not produce a compressed model satisfying the set accuracy '
            'degradation criterion during training. Increasing the number of training '
            'epochs')
        # Nothing qualifies: selecting from the empty list below would raise IndexError.
        return

    best_checkpoint_compression_rate = max(possible_checkpoint_rates)
    resuming_checkpoint_path = self._best_checkpoints[best_checkpoint_compression_rate]
    nncf_logger.info('Loading the best checkpoint found during training '
                     '{}...'.format(resuming_checkpoint_path))
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted locations.
    resuming_checkpoint = torch.load(resuming_checkpoint_path, map_location='cpu')
    resuming_model_state_dict = resuming_checkpoint.get('state_dict', resuming_checkpoint)
    load_state(model, resuming_model_state_dict, is_resume=True)
def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None:
    """
    Shrink a convolution's output channels according to the node's output mask.

    The weight is masked along dim 0 (output channels of a Conv weight). The bias is masked too when
    present, and `out_channels` is updated. Does nothing when the node has no output mask.
    """
    mask = node.data['output_mask']
    if mask is None:
        return

    bool_mask = torch.tensor(mask, dtype=torch.bool)
    node_module = model.get_containing_module(node.node_name)
    old_num_channels = int(node_module.weight.size(0))

    # Count kept channels from the converted bool tensor: `mask` may be a list/array,
    # which torch.sum does not accept.
    node_module.out_channels = int(torch.sum(bool_mask))
    node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
    if node_module.bias is not None:
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

    nncf_logger.info(
        'Pruned Convolution {} by pruning mask. Old output filters number: {}, new filters number:'
        ' {}.'.format(node.data['key'], old_num_channels, node_module.out_channels))
def dump_checkpoint(self, model, compression_controller):
    """
    Save a training checkpoint.

    If a custom `_dump_checkpoint_fn` is configured, only the main process calls it with
    (model, compression_controller, self, self._log_dir), and nothing else is saved. Otherwise a
    default checkpoint is written to 'acc_aware_checkpoint_last.pth' in `self._checkpoint_save_dir`
    and passed to `_save_best_checkpoint`. The default checkpoint holds the epoch, model state,
    compression state, metric values, optimizer state and scheduler state.
    """
    if self._dump_checkpoint_fn is not None:
        # Non-main processes must not fall through to the default save below when a custom dump
        # function is configured.
        if is_main_process():
            self._dump_checkpoint_fn(model, compression_controller, self, self._log_dir)
        return

    # NOTE(review): without a custom function every process takes this path and writes the same
    # file; confirm this is intended for multi-process runs.
    checkpoint = {
        'epoch': self.cumulative_epoch_count + 1,
        'state_dict': model.state_dict(),
        'compression_state': compression_controller.get_compression_state(),
        'best_metric_val': self.best_val_metric_value,
        'current_val_metric_value': self.current_val_metric_value,
        'optimizer': self.optimizer.state_dict(),
        'scheduler': compression_controller.scheduler.get_state()
    }
    checkpoint_path = osp.join(self._checkpoint_save_dir, 'acc_aware_checkpoint_last.pth')
    torch.save(checkpoint, checkpoint_path)
    nncf_logger.info("The checkpoint is saved in {}".format(checkpoint_path))
    self._save_best_checkpoint(checkpoint_path)
def run(self, model, train_epoch_fn, validate_fn, configure_optimizers_fn=None,
        dump_checkpoint_fn=None, tensorboard_writer=None, log_dir=None):
    """
    Train the compressed model until the accuracy criterion is met or the epoch limit is reached.

    The model is validated before every training epoch. As soon as the accuracy budget (compressed
    accuracy minus the minimal tolerable accuracy) is non-negative, the loop exits and the model is
    returned. If the criterion is never met within `maximal_total_epochs`, the runner's best
    checkpoint is loaded into the model.

    :return: The model.
    """
    self.runner.initialize_training_loop_fns(train_epoch_fn, validate_fn, configure_optimizers_fn,
                                             dump_checkpoint_fn, tensorboard_writer, log_dir)
    self.runner.retrieve_uncompressed_model_accuracy(model)
    uncompressed_model_accuracy = self.runner.uncompressed_model_accuracy
    self.runner.calculate_minimal_tolerable_accuracy(uncompressed_model_accuracy)

    self.runner.configure_optimizers()
    for epoch in range(self.runner.maximal_total_epochs):
        compressed_model_accuracy = self.runner.validate(model)
        accuracy_budget = compressed_model_accuracy - self.runner.minimal_tolerable_accuracy
        accuracy_drop = uncompressed_model_accuracy - compressed_model_accuracy
        try:
            rel_accuracy_drop = 100 * (accuracy_drop / uncompressed_model_accuracy)
        except ZeroDivisionError:
            # The relative drop is undefined for a zero baseline accuracy; report 0.
            rel_accuracy_drop = 0

        if accuracy_budget >= 0:
            if epoch == 0:
                nncf_logger.info(
                    'The accuracy criteria is reached. '
                    'Exiting the training loop after initialization step '
                    'with compressed model accuracy value {:.4f}. Original model accuracy is {:.4f}. '
                    'The absolute accuracy drop is {:.4f}. '
                    'The relative accuracy drop is {:.2f}%.'.format(
                        compressed_model_accuracy, uncompressed_model_accuracy,
                        accuracy_drop, rel_accuracy_drop))
                self.runner.dump_statistics(model, self.compression_controller)
                return model
            nncf_logger.info(
                'The accuracy criteria is reached. '
                'Exiting the training loop on epoch {} with '
                'compressed model accuracy value {:.4f}. Original model accuracy is {:.4f}. '
                'The absolute accuracy drop is {:.4f}. '
                'The relative accuracy drop is {:.2f}%.'.format(
                    epoch, compressed_model_accuracy, uncompressed_model_accuracy,
                    accuracy_drop, rel_accuracy_drop))
            return model

        nncf_logger.info('The absolute accuracy drop is {:.4f}. '
                         'The relative accuracy drop is {:.2f}%.'.format(accuracy_drop, rel_accuracy_drop))
        self.runner.train_epoch(model, self.compression_controller)
        self.runner.dump_statistics(model, self.compression_controller)

    self.runner.load_best_checkpoint(model)
    return model
def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None:
    """
    Shrink a GroupNorm module's channels according to the node's first input mask.

    Both `num_channels` and `num_groups` are set to the number of kept channels. This presumably
    assumes one channel per group for prunable GroupNorm; TODO confirm. Affine weight and bias are
    masked when present. Does nothing when the node has no input mask.
    """
    input_mask = node.data['input_masks'][0]
    if input_mask is None:
        return

    node_module = model.get_containing_module(node.node_name)
    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    # Read the channel count from the module so that affine=False GroupNorm (weight is None) works.
    old_num_channels = int(node_module.num_channels)
    # Count kept channels from the converted bool tensor: `input_mask` may be a list/array,
    # which torch.sum does not accept.
    new_num_channels = int(torch.sum(bool_mask))

    node_module.num_channels = new_num_channels
    node_module.num_groups = new_num_channels
    # GroupNorm without affine parameters has weight/bias set to None; nothing to slice then.
    if node_module.weight is not None:
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
    if node_module.bias is not None:
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

    nncf_logger.info(
        'Pruned GroupNorm {} by input mask. Old num features: {}, new num features:'
        ' {}.'.format(node.data['key'], old_num_channels, new_num_channels))
def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None:
    """
    Prune the weight of an NNCF-wrapped user elementwise module by the node's first input mask.

    Only modules that are instances of the types in `NNCF_WRAPPED_USER_MODULES_DICT` (its keys) are
    modified. Their weight is masked along dim 0 and `n_channels` is updated. Other modules are left
    untouched, and nothing happens when the node has no input mask.
    """
    input_mask = node.data['input_masks'][0]
    if input_mask is None:
        return

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    node_module = model.get_containing_module(node.node_name)

    if isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
        assert node_module.target_weight_dim_for_compression == 0, \
            "Implemented only for target_weight_dim_for_compression == 0"
        old_num_channels = int(node_module.weight.size(0))
        # Count kept channels from the converted bool tensor: `input_mask` may be a list/array,
        # which torch.sum does not accept.
        new_num_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
        node_module.n_channels = new_num_channels
        nncf_logger.info(
            'Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
            ' {}.'.format(node.data['key'], old_num_channels, new_num_channels))