def get_average_traces(self, max_iter=100, tolerance=1e-3) -> Tensor:
    """
    Estimates the average Hessian trace for each parameter.
    :param max_iter: maximum number of iterations of the Hutchinson algorithm
    :param tolerance: minimum relative tolerance for stopping the algorithm, calculated between
        the mean average trace of the previous iteration and that of the current one
    :return: Tensor with the average Hessian trace per parameter
    """
    avg_total_trace = 0.
    avg_traces_per_iter = []  # type: List[Tensor]
    mean_avg_traces_per_param = None

    for i in range(max_iter):
        avg_traces_per_iter.append(self._calc_avg_traces_per_param())

        mean_avg_traces_per_param = self._get_mean(avg_traces_per_iter)
        mean_avg_total_trace = torch.sum(mean_avg_traces_per_param)

        diff_avg = abs(mean_avg_total_trace - avg_total_trace) / (avg_total_trace + self._diff_eps)
        if diff_avg < tolerance:
            return mean_avg_traces_per_param

        avg_total_trace = mean_avg_total_trace
        nncf_logger.info('{}# difference_avg={} avg_trace={}'.format(i, diff_avg, avg_total_trace))

    return mean_avg_traces_per_param
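# A minimal, self-contained sketch of the Hutchinson estimator that the loop above relies on:
# for a symmetric matrix H and Rademacher probe vectors v, E[v^T H v] = tr(H), so averaging
# v^T H v over iterations converges to the trace. The helper name, the toy matrix and the
# 1e-6 stabilizer are illustrative assumptions, not part of NNCF.
import torch

def hutchinson_trace_sketch(H: torch.Tensor, max_iter: int = 100, tolerance: float = 1e-3) -> float:
    estimates = []
    prev_mean, mean = 0.0, 0.0
    for _ in range(max_iter):
        # Rademacher probe vector with +/-1 entries
        v = (torch.randint(0, 2, (H.shape[0],)) * 2 - 1).to(H.dtype)
        estimates.append(torch.dot(v, H @ v).item())
        mean = sum(estimates) / len(estimates)
        # Same relative-difference stopping rule as in get_average_traces above
        if abs(mean - prev_mean) / (abs(prev_mean) + 1e-6) < tolerance:
            break
        prev_mean = mean
    return mean

# Usage: A = torch.randn(64, 64); H = A + A.T
# hutchinson_trace_sketch(H) is roughly torch.trace(H).item()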
def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    output_mask = nx_node['output_mask']
    if output_mask is None:
        return

    bool_mask = torch.tensor(output_mask, dtype=torch.bool)
    new_num_channels = int(torch.sum(bool_mask))

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)
    old_num_channels = int(node_module.weight.size(1))

    # For ConvTranspose, output channels lie along dim 1 of the weight,
    # so the output mask is broadcast across dim 0
    in_channels = node_module.weight.size(0)
    broadcasted_mask = bool_mask.repeat(in_channels).view(in_channels, bool_mask.size(0))
    new_weight_shape = list(node_module.weight.shape)
    new_weight_shape[1] = new_num_channels

    node_module.out_channels = new_num_channels
    node_module.weight = torch.nn.Parameter(node_module.weight[broadcasted_mask].view(new_weight_shape))

    if node_module.bias is not None:
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

    nncf_logger.info('Pruned ConvTranspose {} by pruning mask. Old output filters number: {}, new filters number:'
                     ' {}.'.format(nx_node['key'], old_num_channels, node_module.out_channels))
def load_state(model: torch.nn.Module, saved_state_dict: dict, is_resume: bool = False) -> int:
    """
    Used to load a checkpoint containing a compressed model into an NNCFNetwork object, but can
    be used for any PyTorch module as well. Matches the saved_state_dict parameters to the
    model's state_dict parameters while discarding irrelevant prefixes added during wrapping in
    NNCFNetwork or DataParallel/DistributedDataParallel objects, and loads the matched parameters
    from the saved_state_dict into the model's state dict.
    :param model: The target module for the saved_state_dict to be loaded to.
    :param saved_state_dict: A state dict containing the parameters to be loaded into the model.
    :param is_resume: Determines the behavior when the function cannot do a successful parameter
        match when loading. If True, the function will raise an exception if it cannot match the
        saved_state_dict parameters to the model's parameters (i.e. if some parameters required
        by the model are missing in saved_state_dict, if saved_state_dict has parameters that
        could not be matched to model parameters, or if the shapes of the parameters do not
        match). If False, the exception will not be raised.
        Usually is_resume is specified as False when loading an uncompressed model's weights
        into a model with compression algorithms already applied, and as True when loading a
        compressed model's weights into a model with compression algorithms applied, in order
        to evaluate the model.
    :return: The number of saved_state_dict entries successfully matched and loaded into model.
    """
    def key_normalizer(key):
        new_key = key
        match = re.search('(pre_ops|post_ops)\\.(\\d+?)\\.op', key)
        return new_key if not match else new_key.replace(match.group(), 'operation')

    if 'state_dict' in saved_state_dict:
        saved_state_dict = saved_state_dict['state_dict']
    state_dict = model.state_dict()

    new_dict, num_loaded_layers, problematic_keys = match_keys(is_resume, saved_state_dict, state_dict,
                                                               key_normalizer)
    num_saved_layers = len(saved_state_dict.items())
    process_problematic_keys(is_resume, problematic_keys, num_loaded_layers == num_saved_layers)
    nncf_logger.info("Loaded {}/{} layers".format(num_loaded_layers, len(state_dict.items())))

    model.load_state_dict(new_dict, strict=False)
    return num_loaded_layers
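# A small illustration of the key normalization performed by load_state above: parameters
# registered under a numbered `pre_ops.<N>.op`/`post_ops.<N>.op` container are matched against
# the bare `operation` name. The regex is the one from the function; the example key is made up.
import re

def _normalizer_demo(key):
    match = re.search('(pre_ops|post_ops)\\.(\\d+?)\\.op', key)
    return key if not match else key.replace(match.group(), 'operation')

assert _normalizer_demo('module.conv1.pre_ops.0.op.scale') == 'module.conv1.operation.scale'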
def _apply_minmax_init(self, min_values, max_values, log_module_name: str = None):
    if torch.any(torch.eq(min_values, np.inf)) or torch.any(torch.eq(max_values, -np.inf)):
        raise AttributeError('Statistics were not collected for {}'.format(log_module_name))

    sign = torch.any(torch.lt(min_values, 0))
    if self.signedness_to_force is not None and sign != self.signedness_to_force:
        nncf_logger.warning("Forcing signed to {} for module {}".format(self.signedness_to_force, log_module_name))
        sign = self.signedness_to_force
    self.signed = int(sign)

    abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
    SCALE_LOWER_THRESHOLD = 0.1
    mask = torch.gt(abs_max, SCALE_LOWER_THRESHOLD)
    self._scale_param_storage.data = torch.where(mask, abs_max,
                                                 SCALE_LOWER_THRESHOLD * torch.ones_like(self._scale_param_storage))
    if self._is_using_log_scale_storage:
        self._scale_param_storage.data.log_()

    nncf_logger.info("Set sign: {} and scale: {} for {}".format(self.signed,
                                                                get_flat_tensor_contents_string(self.scale),
                                                                log_module_name))
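# A standalone illustration of the symmetric scale initialization above: the scale is the
# per-channel absolute maximum of the collected statistics, floored at SCALE_LOWER_THRESHOLD
# so that no channel gets a degenerate quantization range. The statistics are made-up values.
import torch

min_values = torch.tensor([-0.5, -0.01, 0.0])
max_values = torch.tensor([0.3, 0.02, 0.05])
abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))  # tensor([0.5000, 0.0200, 0.0500])
SCALE_LOWER_THRESHOLD = 0.1
scale = torch.where(abs_max > SCALE_LOWER_THRESHOLD, abs_max,
                    torch.full_like(abs_max, SCALE_LOWER_THRESHOLD))
# scale is tensor([0.5000, 0.1000, 0.1000])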
def _sparsify_weights(self, target_model: NNCFNetwork) -> List[InsertionCommand]:
    device = next(target_model.parameters()).device
    sparsified_modules = target_model.get_nncf_modules_by_module_names(self.compressed_nncf_module_names)
    insertion_commands = []
    for module_scope, module in sparsified_modules.items():
        scope_str = str(module_scope)

        if not self._should_consider_scope(scope_str):
            nncf_logger.info("Ignored adding Weight Sparsifier in scope: {}".format(scope_str))
            continue

        nncf_logger.info("Adding Weight Sparsifier in scope: {}".format(scope_str))
        operation = self.create_weight_sparsifying_operation(module)
        hook = UpdateWeight(operation).to(device)
        insertion_commands.append(InsertionCommand(InsertionPoint(InsertionType.NNCF_MODULE_PRE_OP,
                                                                  module_scope=module_scope),
                                                   hook, OperationPriority.SPARSIFICATION_PRIORITY))
        self._sparsified_module_info.append(SparseModuleInfo(scope_str, module, hook.operand))
    return insertion_commands
def apply_mask(self): """ Applying propagated masks for all nodes in topological order: 1. running input_prune method for this node 2. running output_prune method for this node """ sorted_nodes = [ self.nx_graph.nodes[name] for name in nx.topological_sort(self.nx_graph) ] pruned_node_modules = list() with torch.no_grad(): for node in sorted_nodes: node_type = self.graph.node_type_fn(node) node_cls = self.get_class_by_type_name(node_type)() nncf_node = self.graph._nx_node_to_nncf_node(node) node_module = self.model.get_module_by_scope( nncf_node.op_exec_context.scope_in_model) # Some modules can be associated with several nodes if node_module not in pruned_node_modules: node_cls.input_prune(self.model, node, self.graph, self.nx_graph) node_cls.output_prune(self.model, node, self.graph, self.nx_graph) pruned_node_modules.append(node_module) nncf_logger.info('Finished mask applying step')
def run_batchnorm_adaptation(self, config):
    initializer_params = config.get("initializer", {})
    init_bn_adapt_config = initializer_params.get('batchnorm_adaptation', {})
    num_bn_adaptation_samples = init_bn_adapt_config.get('num_bn_adaptation_samples', 0)
    num_bn_forget_samples = init_bn_adapt_config.get('num_bn_forget_samples', 0)

    try:
        bn_adaptation_args = config.get_extra_struct(BNAdaptationInitArgs)
        has_bn_adapt_init_args = True
    except KeyError:
        has_bn_adapt_init_args = False

    if not init_bn_adapt_config:
        if has_bn_adapt_init_args:
            nncf_logger.warning("Enabling quantization batch norm adaptation with default parameters.")
            num_bn_adaptation_samples = 2000
            num_bn_forget_samples = 1000

    if num_bn_adaptation_samples < 0:
        raise AttributeError('Number of adaptation samples must be >= 0')
    if num_bn_adaptation_samples > 0:
        if not has_bn_adapt_init_args:
            nncf_logger.info('Could not run batchnorm adaptation '
                             'as the adaptation data loader is not provided as an extra struct. '
                             'Refer to `NNCFConfig.register_extra_structs` and the `BNAdaptationInitArgs` class')
            return
        batch_size = bn_adaptation_args.data_loader.batch_size
        num_bn_forget_steps = numpy.ceil(num_bn_forget_samples / batch_size)
        num_bn_adaptation_steps = numpy.ceil(num_bn_adaptation_samples / batch_size)
        bn_adaptation_runner = DataLoaderBNAdaptationRunner(self._model, bn_adaptation_args.device,
                                                            num_bn_forget_steps)
        bn_adaptation_runner.run(bn_adaptation_args.data_loader, num_bn_adaptation_steps)
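# Worked example of the samples-to-steps conversion above: with the default 2000 adaptation /
# 1000 forget samples and an (assumed) batch size of 64, the runner performs
# ceil(1000 / 64) = 16 forget steps and ceil(2000 / 64) = 32 adaptation steps.
import numpy
assert numpy.ceil(1000 / 64) == 16 and numpy.ceil(2000 / 64) == 32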
def propagate_can_prune_attr_down(self):
    """
    Propagates the can_prune attribute down to fix all branching cases where one branch
    is pruned and another is not.
    """
    sorted_nodes = [self.nx_graph.nodes[name] for name in nx.topological_sort(self.nx_graph)]
    for node in sorted_nodes:
        # Propagate the attribute only in the non-conv case
        if self.node_propagate_can_prune_attr(node['key']):
            in_edges = self.nx_graph.in_edges(node['key'])
            can_prune = all(self.nx_graph.nodes[key][ModelPruner.CAN_PRUNE_ATTR] for key, _ in in_edges)
            can_prune_any = any(self.nx_graph.nodes[key][ModelPruner.CAN_PRUNE_ATTR] for key, _ in in_edges)

            if (not self.node_accept_different_inputs(node) and not can_prune) or \
                    (self.node_accept_different_inputs(node) and not can_prune_any):
                node[ModelPruner.CAN_PRUNE_ATTR] = can_prune
    nncf_logger.info('Propagated can_prune attribute down')
def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    old_num_channels = int(node_module.weight.size(0))
    new_num_channels = int(torch.sum(bool_mask))

    node_module.num_features = new_num_channels
    node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
    node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])
    node_module.running_mean = torch.nn.Parameter(node_module.running_mean[bool_mask], requires_grad=False)
    node_module.running_var = torch.nn.Parameter(node_module.running_var[bool_mask], requires_grad=False)

    nncf_logger.info('Pruned BatchNorm {} by input mask. Old num features: {}, new num features:'
                     ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
def run_batchnorm_adaptation(self, config):
    initializer_params = config.get("initializer", {})
    init_bn_adapt_config = initializer_params.get('batchnorm_adaptation', {})
    num_bn_adaptation_steps = init_bn_adapt_config.get('num_bn_adaptation_steps', 0)
    num_bn_forget_steps = init_bn_adapt_config.get('num_bn_forget_steps', 5)

    if num_bn_adaptation_steps < 0:
        raise AttributeError('Number of batchnorm adaptation steps must be >= 0')
    if num_bn_adaptation_steps > 0:
        try:
            bn_adaptation_args = config.get_extra_struct(BNAdaptationInitArgs)
        except KeyError:
            nncf_logger.info('Could not run batchnorm adaptation '
                             'as the adaptation data loader is not provided as an extra struct. '
                             'Refer to `NNCFConfig.register_extra_structs` and the `BNAdaptationInitArgs` class')
            return
        bn_adaptation_runner = DataLoaderBNAdaptationRunner(self._model, bn_adaptation_args.device,
                                                            num_bn_forget_steps)
        bn_adaptation_runner.run(bn_adaptation_args.data_loader, num_bn_adaptation_steps)
def apply_mask(self):
    """
    Applies the propagated masks for all nodes in topological order:
    if all inputs of the node have can_prune set -> run the input_prune method for this node
    if node[ModelPruner.CAN_PRUNE_ATTR] -> run the output_prune method for this node
    """
    sorted_nodes = [self.nx_graph.nodes[name] for name in nx.topological_sort(self.nx_graph)]
    with torch.no_grad():
        for node in sorted_nodes:
            node_type = self.graph.node_type_fn(node)
            node_cls = self.get_class_by_type_name(node_type)()

            in_edges = self.nx_graph.in_edges(node['key'])
            can_prune_input = all(self.nx_graph.nodes[key][ModelPruner.CAN_PRUNE_ATTR] for key, _ in in_edges)
            if can_prune_input:
                node_cls.input_prune(self.model, node, self.graph, self.nx_graph)

            if node[ModelPruner.CAN_PRUNE_ATTR]:
                node_cls.output_prune(self.model, node, self.graph, self.nx_graph)
    nncf_logger.info('Finished mask applying step')
def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    new_num_channels = int(torch.sum(bool_mask))

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    is_depthwise = nx_node['is_depthwise']
    old_num_channels = int(node_module.weight.size(1))

    if is_depthwise:
        # In the depthwise case, output channels are pruned by the input mask;
        # here, only the new number of input channels (and groups) is fixed up
        node_module.groups = new_num_channels
        node_module.in_channels = new_num_channels
    else:
        out_channels = node_module.weight.size(0)
        broadcasted_mask = bool_mask.repeat(out_channels).view(out_channels, bool_mask.size(0))
        new_weight_shape = list(node_module.weight.shape)
        new_weight_shape[1] = new_num_channels

        node_module.in_channels = new_num_channels
        node_module.weight = torch.nn.Parameter(node_module.weight[broadcasted_mask].view(new_weight_shape))

    nncf_logger.info('Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
                     ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
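# A standalone check of the broadcasted-mask trick used above: masking a conv weight of shape
# (out_channels, in_channels, kH, kW) with the repeated/reshaped boolean mask and restoring the
# shape is equivalent to plain advanced indexing along dim 1. Shapes and mask are arbitrary.
import torch

weight = torch.randn(4, 6, 3, 3)
bool_mask = torch.tensor([1, 0, 1, 1, 0, 1], dtype=torch.bool)
out_channels = weight.size(0)
broadcasted_mask = bool_mask.repeat(out_channels).view(out_channels, bool_mask.size(0))
new_weight_shape = [out_channels, int(bool_mask.sum()), 3, 3]
pruned = weight[broadcasted_mask].view(new_weight_shape)
assert torch.equal(pruned, weight[:, bool_mask])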
def apply_minmax_init(self, min_values, max_values, log_module_name: str = None):
    if self.initialized:
        nncf_logger.debug("Skipped initializing {} - loaded from checkpoint".format(log_module_name))
        return
    if torch.any(torch.eq(min_values, np.inf)) or torch.any(torch.eq(max_values, -np.inf)):
        raise AttributeError('Statistics were not collected for {}'.format(log_module_name))

    sign = torch.any(torch.lt(min_values, 0))
    if self.signedness_to_force is not None and sign != self.signedness_to_force:
        nncf_logger.warning("Forcing signed to {} for module {}".format(self.signedness_to_force, log_module_name))
        sign = self.signedness_to_force
    self.signed = int(sign)

    abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
    SCALE_LOWER_THRESHOLD = 0.1
    mask = torch.gt(abs_max, SCALE_LOWER_THRESHOLD)
    self.scale.fill_(SCALE_LOWER_THRESHOLD)
    # masked_scatter_ consumes the source sequentially, so pre-select the elements
    # at the masked positions to keep the assignment elementwise
    self.scale.masked_scatter_(mask, abs_max[mask])

    nncf_logger.info("Set sign: {} and scale: {} for {}".format(self.signed,
                                                                get_flat_tensor_contents_string(self.scale),
                                                                log_module_name))
def dump_metric(self, configuration_metric: List[Tensor]):
    import matplotlib.pyplot as plt
    list_to_plot = [cm.item() for cm in configuration_metric]
    fig = plt.figure()
    fig.suptitle('Pareto Frontier')
    ax = fig.add_subplot(2, 1, 1)
    ax.set_yscale('log')
    ax.set_xlabel('Model Size (MB)')
    ax.set_ylabel('Metric value (total perturbation)')
    ax.scatter(self._model_sizes, list_to_plot, s=20, facecolors='none', edgecolors='r')
    cm = torch.Tensor(configuration_metric)
    cm_m = cm.median().item()
    configuration_index = configuration_metric.index(cm_m)
    ms_m = self._model_sizes[configuration_index]
    ax.scatter(ms_m, cm_m, s=30, facecolors='none', edgecolors='b', label='median from all metrics')
    ax.legend()
    plt.savefig(os.path.join(self._dump_dir, 'Pareto_Frontier'))
    nncf_logger.info('Distribution of HAWQ metrics: min_value={:.3f}, max_value={:.3f}, median_value={:.3f}, '
                     'median_index={}, total_number={}'.format(cm.min().item(), cm.max().item(), cm_m,
                                                               configuration_index, len(configuration_metric)))
def apply_minmax_init(self, min_values, max_values, log_module_name: str = None):
    if self.initialized:
        nncf_logger.debug("Skipped initializing {} - loaded from checkpoint".format(log_module_name))
        return
    if torch.any(torch.eq(min_values, np.inf)) or torch.any(torch.eq(max_values, -np.inf)):
        raise AttributeError('Statistics were not collected for {}'.format(log_module_name))

    ranges = max_values - min_values
    max_range = torch.max(ranges)
    eps = 1e-2
    correction = (clamp(ranges, low=eps * max_range, high=max_range) - ranges) * 0.5
    self.input_range.data = (ranges + 2 * correction).data
    self.input_low.data = (min_values - correction).data

    nncf_logger.info("Set input_low: {} and input_range: {} for {}".format(
        get_flat_tensor_contents_string(self.input_low),
        get_flat_tensor_contents_string(self.input_range), log_module_name))
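# A standalone illustration of the range correction above: per-channel ranges narrower than
# eps * max_range are widened symmetrically, so no channel ends up with a degenerate range.
# torch.clamp stands in for NNCF's clamp(low=..., high=...) helper; the statistics are made up.
import torch

min_values = torch.tensor([-1.0, -0.001])
max_values = torch.tensor([1.0, 0.001])
ranges = max_values - min_values                    # tensor([2.0000, 0.0020])
max_range = torch.max(ranges)                       # tensor(2.)
eps = 1e-2
correction = (torch.clamp(ranges, min=eps * max_range, max=max_range) - ranges) * 0.5
input_range = ranges + 2 * correction               # tensor([2.0000, 0.0200])
input_low = min_values - correction                 # tensor([-1.0000, -0.0100])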
def propagate_can_prune_attr_up(self):
    """
    Propagates the can_prune attribute in reversed topological order. This attribute depends on
    the accept_pruned_input and can_prune attributes of the output nodes. A node's can_prune is
    True if accept_pruned_input is True for all of its outputs and can_prune is True for all of
    its outputs (convs excepted, because a conv can be pruned by input and output independently).
    """
    for node_name in self.nx_graph.nodes:
        self.nx_graph.nodes[node_name][ModelPruner.CAN_PRUNE_ATTR] = True

    reversed_sorted_nodes = reversed([self.nx_graph.nodes[name] for name in nx.topological_sort(self.nx_graph)])
    for node in reversed_sorted_nodes:
        # Check the accept_pruned_input attribute of all output nodes
        out_edges = self.nx_graph.out_edges(node['key'])
        outputs_accept_pruned_input = all(self.nx_graph.nodes[key]['accept_pruned_input'] for _, key in out_edges)

        # Check the can_prune attribute of all output nodes
        outputs_will_be_pruned = all(self.nx_graph.nodes[key][ModelPruner.CAN_PRUNE_ATTR]
                                     for _, key in out_edges if self.node_propagate_can_prune_attr(key))

        node[ModelPruner.CAN_PRUNE_ATTR] = outputs_accept_pruned_input and outputs_will_be_pruned
    nncf_logger.info('Propagated can_prune attribute up')
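# A toy sketch of the upward pass on a networkx DiGraph: a node keeps can_prune=True only if
# every successor both accepts pruned input and will itself be pruned. The graph, the 'accepts'
# attribute and the omission of the conv exception are simplifying assumptions for illustration.
import networkx as nx

g = nx.DiGraph()
g.add_edges_from([('conv1', 'relu'), ('relu', 'concat'), ('conv2', 'concat')])
for n in g.nodes:
    g.nodes[n].update(can_prune=True, accepts=True)
g.nodes['concat']['accepts'] = False  # e.g. this concat cannot consume pruned input

for node in reversed(list(nx.topological_sort(g))):
    g.nodes[node]['can_prune'] = all(g.nodes[s]['accepts'] and g.nodes[s]['can_prune']
                                     for s in g.successors(node))

# conv1, conv2 and relu end up with can_prune=False because the concat refuses pruned input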
def _compute_and_display_flops_binarization_rate(self):
    net = self._model
    weight_list = {}
    state_dict = net.state_dict()
    for n, v in state_dict.items():
        weight_list[n] = v.clone()

    ops_dict = OrderedDict()

    def get_hook(name):
        def compute_flops_hook(self, input_, output):
            name_type = str(type(self).__name__)
            if isinstance(self, (nn.Conv2d, nn.ConvTranspose2d)):
                ks = self.weight.data.shape
                ops_count = ks[0] * ks[1] * ks[2] * ks[3] * output.shape[3] * output.shape[2]
            elif isinstance(self, nn.Linear):
                ops_count = input_[0].shape[1] * output.shape[1]
            else:
                return
            ops_dict[name] = (name_type, ops_count, isinstance(self, NNCFConv2d))
        return compute_flops_hook

    hook_list = [m.register_forward_hook(get_hook(n)) for n, m in net.named_modules()]
    net.do_dummy_forward(force_eval=True)
    for h in hook_list:
        h.remove()

    # Restore all parameters that can be corrupted due to the forward pass
    for n, v in state_dict.items():
        state_dict[n].data.copy_(weight_list[n].data)

    ops_bin = 0
    ops_total = 0
    for layer_name, (layer_type, ops, is_binarized) in ops_dict.items():
        ops_total += ops
        if is_binarized:
            ops_bin += ops

    table = Texttable()
    header = ["Layer name", "Layer type", "Binarized", "MAC count", "MAC share"]
    table_data = [header]
    for layer_name, (layer_type, ops, is_binarized) in ops_dict.items():
        drow = {h: 0 for h in header}
        drow["Layer name"] = layer_name
        drow["Layer type"] = layer_type
        drow["Binarized"] = 'Y' if is_binarized else 'N'
        drow["MAC count"] = "{:.3f}G".format(ops * 1e-9)
        drow["MAC share"] = "{:2.1f}%".format(ops / ops_total * 100)
        row = [drow[h] for h in header]
        table_data.append(row)
    table.add_rows(table_data)
    nncf_logger.info(table.draw())
    nncf_logger.info("Total binarized MAC share: {:.1f}%".format(ops_bin / ops_total * 100))
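# A self-contained sketch of the hook-based MAC counting above, on a plain model: conv MACs are
# taken from the weight shape times the output spatial size, mirroring the hook's formula. The
# model and input size are arbitrary examples.
import torch
import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 4, 3, padding=1))
ops = OrderedDict()

def make_hook(name):
    def hook(module, inputs, output):
        if isinstance(module, nn.Conv2d):
            ks = module.weight.shape  # (out_channels, in_channels / groups, kH, kW)
            ops[name] = ks[0] * ks[1] * ks[2] * ks[3] * output.shape[2] * output.shape[3]
    return hook

handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
with torch.no_grad():
    model(torch.randn(1, 3, 16, 16))
for h in handles:
    h.remove()
# ops == OrderedDict([('0', 55296), ('2', 73728)])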
def _run_quantization_pipeline(self, finetune=False) -> float:
    self.qctrl.run_batchnorm_adaptation(self.qctrl.quantization_config)

    if finetune:
        raise NotImplementedError("Post-Quantization fine tuning is not implemented.")
    with torch.no_grad():
        quantized_score = self.eval_fn(self.qmodel, self.eval_loader)
        logger.info("[Q.Env] Quantized Score: {:.3f}".format(quantized_score))
    return quantized_score
def prune_model(self):
    """
    The model pruner works in two stages:
    1. Mask propagation: propagate pruning masks through the graph.
    2. Apply the calculated masks.
    """
    nncf_logger.info('Start pruning model')
    self.mask_propagation()
    self.apply_mask()
    nncf_logger.info('Finished pruning model')
def replace_modules(model: nn.Module, replace_fn, affected_scopes, ignored_scopes=None, target_scopes=None,
                    memo=None, current_scope=None):
    if memo is None:
        memo = set()
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    if model in memo:
        return model, affected_scopes

    memo.add(model)
    for name, module in model.named_children():
        if module is None:
            continue

        child_scope_element = ScopeElement(module.__class__.__name__, name)
        child_scope = current_scope.copy()
        child_scope.push(child_scope_element)
        replaced_module = replace_fn(module)

        if replaced_module is not None:
            replaced_scope_element = ScopeElement(replaced_module.__class__.__name__, name)
            replaced_scope = current_scope.copy()
            replaced_scope.push(replaced_scope_element)
            if module is not replaced_module:
                if in_scope_list(str(child_scope), ignored_scopes):
                    nncf_logger.info("Ignored wrapping modules in scope: {}".format(child_scope))
                    continue

                if target_scopes is None or in_scope_list(str(child_scope), target_scopes):
                    nncf_logger.info("Wrapping module {} by {}".format(str(child_scope), str(replaced_scope)))
                    if isinstance(model, nn.Sequential):
                        # pylint: disable=protected-access
                        model._modules[name] = replaced_module
                    else:
                        setattr(model, name, replaced_module)
                    affected_scopes.append(replaced_scope)
            elif is_nncf_module(replaced_module):
                # Got an NNCF-wrapped module from a previous compression stage, track its scope as well
                affected_scopes.append(replaced_scope)

        _, affected_scopes = replace_modules(module, replace_fn, affected_scopes, ignored_scopes, target_scopes,
                                             memo, child_scope)
    return model, affected_scopes
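# A hypothetical replace_fn to pair with replace_modules above: it swaps every nn.Conv2d for a
# made-up wrapper class and returns other modules unchanged (returning the module itself means
# "no replacement"). Groups and dilation handling are omitted for brevity.
import torch.nn as nn

class WrappedConv2d(nn.Conv2d):
    """Placeholder stand-in for an NNCF-wrapped convolution."""

def wrap_conv(module):
    if isinstance(module, nn.Conv2d) and not isinstance(module, WrappedConv2d):
        wrapped = WrappedConv2d(module.in_channels, module.out_channels, module.kernel_size,
                                stride=module.stride, padding=module.padding,
                                bias=module.bias is not None)
        wrapped.load_state_dict(module.state_dict())
        return wrapped
    return module

# model, affected = replace_modules(model, wrap_conv, affected_scopes=[])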
def apply_init(self):
    original_device = next(self._model.parameters()).device
    self._model.to(self._init_device)

    traces_per_layer = self._calc_traces(self._criterion, self._iter_number, self._tolerance)
    if not traces_per_layer:
        raise RuntimeError('Failed to calculate hessian traces!')

    num_weights = len(self._ordered_weight_quantizations)
    bits_configurations = self.get_configs_constrained_by_order(self._bits, num_weights)

    ordered_weight_quantization_ids = list(self._ordered_weight_quantizations.keys())
    bits_configurations = self._filter_configs_by_precision_constraints(bits_configurations,
                                                                        self._hw_precision_constraints,
                                                                        ordered_weight_quantization_ids,
                                                                        traces_per_layer.get_order_of_traces())
    if not bits_configurations:
        raise RuntimeError('All bits configurations are incompatible with HW Config!')

    perturbations, weight_observers = self.calc_quantization_noise()

    configuration_metric = self.calc_hawq_metric_per_configuration(bits_configurations, perturbations,
                                                                   traces_per_layer, self._init_device)

    chosen_config_per_layer = self.choose_configuration(configuration_metric, bits_configurations,
                                                        traces_per_layer.get_order_of_traces())
    self.set_chosen_config(chosen_config_per_layer)

    ordered_metric_per_layer = self.get_metric_per_layer(chosen_config_per_layer, perturbations, traces_per_layer)

    if is_debug():
        hawq_debugger = HAWQDebugger(bits_configurations, perturbations, weight_observers, traces_per_layer,
                                     self._bits)
        hawq_debugger.dump_metric(configuration_metric)
        hawq_debugger.dump_avg_traces()
        hawq_debugger.dump_density_of_quantization_noise()
        hawq_debugger.dump_perturbations_ratio()
        hawq_debugger.dump_bitwidth_graph(self._algo, self._model)

    self._model.rebuild_graph()
    str_bw = [str(element) for element in self.get_bitwidth_per_scope()]
    nncf_logger.info('\n'.join(['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))

    self._model.to(original_device)
    return ordered_metric_per_layer
def _evaluate_pretrained_model(self):
    logger.info("[Q.Env] Evaluating Pretrained Model")
    self.qctrl.disable_weight_quantization()
    self.qctrl.disable_activation_quantization()

    with torch.no_grad():
        self.pretrained_score = self.eval_fn(self.qmodel, self.eval_loader)
        logger.info("Pretrained Score: {:.3f}".format(self.pretrained_score))

    self.qctrl.enable_weight_quantization()
    self.qctrl.enable_activation_quantization()
    self.qmodel.rebuild_graph()
def choose_configuration(self, configuration_metric: List[Tensor], bits_configurations: List[List[int]],
                         traces_order: List[int]) -> List[int]:
    num_weights = len(traces_order)
    ordered_config = [0] * num_weights
    median_metric = torch.Tensor(configuration_metric).to(self._device).median()
    configuration_index = configuration_metric.index(median_metric)
    bit_configuration = bits_configurations[configuration_index]
    for i, bitwidth in enumerate(bit_configuration):
        ordered_config[traces_order[i]] = bitwidth
    if is_main_process():
        nncf_logger.info('Chosen HAWQ configuration (bitwidth per weightable layer)={}'.format(ordered_config))
        nncf_logger.debug('Order of the weightable layers in the HAWQ configuration={}'.format(traces_order))
    return ordered_config
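# A toy illustration of the median-based selection above: the configuration whose metric equals
# the median of all metrics is the one chosen. Values are arbitrary; 0-dim tensors compare
# element-wise, which is what makes list.index() work here.
import torch

configuration_metric = [torch.tensor(3.0), torch.tensor(1.0), torch.tensor(2.0)]
median_metric = torch.Tensor(configuration_metric).median()
assert configuration_metric.index(median_metric) == 2  # the entry equal to the median, 2.0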
def apply_init(self):
    disabled_gradients = self.disable_quantizer_gradients(self._all_quantizers_per_scope,
                                                          self._algo.quantized_weight_modules_registry,
                                                          self._model)
    traces_per_layer = self._calc_traces(self._criterion, self._iter_number, self._tolerance)
    if not traces_per_layer:
        raise RuntimeError('Failed to calculate hessian traces!')
    self.enable_quantizer_gradients(self._model, self._all_quantizers_per_scope, disabled_gradients)

    num_weights = len(self._ordered_weight_quantizations)
    bits_configurations = self.get_configs_constrained_by_order(self._bits, num_weights)

    ordered_weight_quantization_ids = list(self._ordered_weight_quantizations.keys())
    bits_configurations = self.filter_configs_by_precision_constraints(bits_configurations,
                                                                       self._hw_precision_constraints,
                                                                       ordered_weight_quantization_ids,
                                                                       traces_per_layer.get_order_of_traces())
    if not bits_configurations:
        raise RuntimeError('All bits configurations are incompatible with HW Config!')

    perturbations, weight_observers = self.calc_quantization_noise()

    configuration_metric = self.calc_hawq_metric_per_configuration(bits_configurations, perturbations,
                                                                   traces_per_layer, self._device)

    chosen_config_per_layer = self.choose_configuration(configuration_metric, bits_configurations,
                                                        traces_per_layer.get_order_of_traces())
    self.set_chosen_config(chosen_config_per_layer)
    ordered_metric_per_layer = self.get_metric_per_layer(chosen_config_per_layer, perturbations, traces_per_layer)

    if is_debug():
        self.HAWQDump(bits_configurations, configuration_metric, perturbations, weight_observers,
                      traces_per_layer, self._bits).run()

    self._model.rebuild_graph()
    str_bw = [str(element) for element in self.get_bitwidth_per_scope()]
    nncf_logger.info('\n'.join(['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))
    return ordered_metric_per_layer
def binarize(self, x):
    if self.training and not self.is_scale_initialized:
        # Init scale using non-binarized activation statistics
        d = x.detach().data.contiguous().view(-1)
        top_num = max(1, round(d.shape[0] * 0.001))
        topk_res = d.topk(top_num)
        scale = topk_res[0].min()
        nncf_logger.info("Binarized activation scale set to: {}".format(scale.item()))
        self.scale.data[:] = scale.log()
        self.is_scale_initialized = True
    x = self.bin(x, self.scale.exp(), self.threshold.sigmoid())
    return x
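# A standalone illustration of the scale initialization above: the scale is the smallest of the
# top 0.1% of activation values, i.e. roughly the 99.9th percentile. The input is random sample
# data, not real activations.
import torch

x = torch.randn(100000).abs()
top_num = max(1, round(x.shape[0] * 0.001))
scale = x.topk(top_num).values.min()
# scale is approximately torch.quantile(x, 0.999)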
def mask_propagation(self):
    """
    Mask propagation through the graph: masks are propagated by running the mask_propagation
    method of each node's meta-operation, over all nodes in topological order.
    """
    sorted_nodes = [self.nx_graph.nodes[node_name] for node_name in nx.topological_sort(self.nx_graph)]
    for node in sorted_nodes:
        node_type = self.graph.node_type_fn(node)
        cls = self.get_class_by_type_name(node_type)()
        cls.mask_propagation(self.model, node, self.graph, self.nx_graph)
    nncf_logger.info('Finished mask propagation in graph')
def apply_init(self):
    runner = HessianAwarePrecisionInitializeRunner(self._algo, self._model, self._data_loader,
                                                   self._num_data_points, self._all_quantizations,
                                                   self._ordered_weight_quantizations, self._bits,
                                                   self._traces_per_layer_path)
    runner.run(self._criterion, self._iter_number, self._tolerance)
    self._model.rebuild_graph()

    if self._is_distributed:
        # NOTE: The order of quantization modules must be the same across GPUs to correctly broadcast num_bits
        sorted_quantizers = OrderedDict(sorted(self._all_quantizations.items(), key=lambda x: str(x[0])))
        for quantizer in sorted_quantizers.values():  # type: BaseQuantizer
            quantizer.broadcast_num_bits()
        if is_main_process():
            str_bw = [str(element) for element in self.get_bitwidth_per_scope(sorted_quantizers)]
            nncf_logger.info('\n'.join(['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))
def prune_model(self):
    """
    The model pruner works in three stages:
    1. Mask propagation: propagate pruning masks through the graph.
    2. Propagate the can_prune attribute (up and then down) through the graph. This attribute
       shows whether a node can actually be pruned.
    3. Apply masks in accordance with the can_prune attribute (only where pruning is allowed).
    """
    nncf_logger.info('Start pruning model')
    self.mask_propagation()
    self.propagate_can_prune_attr_up()
    self.propagate_can_prune_attr_down()
    self.apply_mask()
    nncf_logger.info('Finished pruning model')
def _apply_minmax_init(self, min_values, max_values, log_module_name: str = None):
    ranges = max_values - min_values
    max_range = torch.max(ranges)
    eps = 1e-2
    correction = (clamp(ranges, low=eps * max_range, high=max_range) - ranges) * 0.5

    self._input_range_param_storage.data = (ranges + 2 * correction).data
    if self._is_using_log_scale_storage:
        self._input_range_param_storage.data.log_()

    self.input_low.data = (min_values - correction).data

    nncf_logger.info("Set input_low: {} and input_range: {} for {}".format(
        get_flat_tensor_contents_string(self.input_low),
        get_flat_tensor_contents_string(self.input_range), log_module_name))
def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)

    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    # For ConvTranspose, input channels lie along dim 0 of the weight
    old_num_channels = int(node_module.weight.size(0))
    node_module.in_channels = int(torch.sum(bool_mask))
    node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

    nncf_logger.info('Pruned ConvTranspose {} by input mask. Old input filters number: {}, new filters number:'
                     ' {}.'.format(nx_node['key'], old_num_channels, node_module.in_channels))