def _prepare_deployable_graph(self, module, input_args, device, output_dir):
    """Build the optimized device graph for ``module``.

    The module and its inputs are moved to ``device``, converted with
    ``prepare_quantizable_module``, run forward once while output recording
    is enabled, and the resulting graph is passed through the
    ``DevGraphOptimizer`` passes.

    Note: the ``nndct_quant_off`` option is set to True here and is not
    reset before returning.

    Args:
        module: Model to convert.
        input_args: Example input. A tuple is unpacked into positional
            arguments; any other value is passed as the single argument.
        device: Device handed to ``to_device`` and
            ``prepare_quantizable_module``.
        output_dir: Forwarded as ``export_folder``.

    Returns:
        The optimizer's ``dev_graph`` after all passes have run.
    """
    module, input_args = to_device(module, input_args, device)
    quant_module, graph = prepare_quantizable_module(
        module=module,
        input_args=input_args,
        export_folder=output_dir,
        device=device,
    )

    # One forward pass with the quant-off option set and output recording on.
    set_option_value("nndct_quant_off", True)
    register_output_hook(quant_module, record_once=True)
    set_outputs_recorder_status(quant_module, True)
    forward_args = input_args if isinstance(input_args, tuple) else (input_args,)
    quant_module(*forward_args)

    optimizer = DevGraphOptimizer(graph)

    def sync_module_with_dev_graph():
        # Re-attach the module to the current dev graph and refresh blob data.
        connect_module_with_graph(
            quant_module, optimizer.dev_graph, recover_param=False)
        update_nndct_blob_data(quant_module, optimizer.dev_graph)

    sync_module_with_dev_graph()
    optimizer.strip_redundant_ops()
    optimizer.update_op_attrs()
    optimizer.constant_folding()
    optimizer.layout_tranform()
    sync_module_with_dev_graph()

    # Finally point the module back at the graph it was built from.
    connect_module_with_graph(quant_module, graph, recover_param=False)
    return optimizer.dev_graph
def __exit__(self, *args):
    """Restore state when the context is left.

    Puts back the ``nndct_param_corr`` value saved on ``self._param_corr``,
    clears the per-node entries of the quantizer's ``fp_history``, resets the
    ``param_quantized`` / ``param_saved`` flags on the quant model's modules,
    then calls ``setup_calib()`` and ``set_keep_fp(True)`` on the processor.

    Returns None, so an exception raised inside the ``with`` body is not
    suppressed.
    """
    processor = self._processor
    set_option_value("nndct_param_corr", self._param_corr)

    for node in processor.graph.nodes:
        for history in processor.quantizer.fp_history.values():
            if node.name in history:
                history[node.name].clear()

    # Same order as before: every param_quantized flag, then every param_saved.
    for flag in ("param_quantized", "param_saved"):
        for mod in processor.quant_model.modules():
            if hasattr(mod, flag):
                setattr(mod, flag, False)

    processor.setup_calib()
    # don't change tensors' quantization step in re-calibration after fast
    # finetune
    processor.set_keep_fp(True)
def finetune(self, run_fn, run_args):
    """Tune layer parameters one layer at a time, then export them.

    Steps, in order:
      1. Return with a warning if ``quantizer.quant_mode`` is 2.
      2. Back up ``nndct_param_corr`` and set it to 0.
      3. Call ``run_fn`` with ``nndct_quant_off`` True while output-caching
         hooks sit on the INPUT nodes and the last quant nodes.
      4. Call ``run_fn`` again with ``nndct_quant_off`` False.
      5. Set ``quant_mode`` to 2, compute a baseline loss with
         ``eval_loss``, and call ``optimize_layer`` for every quantizable
         node that has parameters.
      6. Clear ``fp_history``, reset module flags, set ``quant_mode`` to 1,
         restore ``nndct_param_corr`` and call ``export_param``.

    Args:
        run_fn: Callable invoked as ``run_fn(*run_args)``; presumably runs
            the model over the calibration data (confirm with callers).
        run_args: Positional arguments for ``run_fn``.
    """
    if self.quantizer.quant_mode == 2:
        NndctScreenLogger().warning(
            "Finetune function will be ignored in test mode!")
        return
    NndctScreenLogger().info(
        "=>Finetuning module parameters for better quantization accuracy... ")

    def reset_module_flag(flag):
        # Set ``flag`` to False on every quant-model module that has it.
        for mod in self.quant_model.modules():
            if hasattr(mod, flag):
                setattr(mod, flag, False)

    # backup option value
    saved_param_corr = NndctOption.nndct_param_corr.value
    set_option_value("nndct_param_corr", 0)

    # cache input and output
    last_quant_nodes = self.collect_last_quant_nodes()
    with torch.no_grad():
        hook_mods = [
            node.module
            for node in self.graph.nodes
            if node.op.type == NNDCT_OP.INPUT or node in last_quant_nodes
        ]
        handlers = self.hook_cache_output(hook_mods)
        set_option_value("nndct_quant_off", True)
        run_fn(*run_args)
        self.clean_hooks(handlers)
    torch.cuda.empty_cache()

    # calibration to get a set of quantization steps
    reset_module_flag("param_quantized")

    # evaluation to get float model tensors
    set_option_value("nndct_quant_off", False)
    with torch.no_grad():
        run_fn(*run_args)
    torch.cuda.empty_cache()

    device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)
    conv_act_pairs = (
        (NNDCT_OP.CONV2D, NNDCT_OP.RELU),
        (NNDCT_OP.CONV2D, NNDCT_OP.RELU6),
        (NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU),
        (NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU6),
        (NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.RELU),
    )
    node_sets = GraphSearcher(self.graph).find_nodes_from_type(
        [PatternType(pattern=list(pair)) for pair in conv_act_pairs])

    # Map each matched conv-like node to the activation node that follows it.
    layer_act_group = {}
    for matches in node_sets.values():
        for conv, act in matches:
            layer_act_group[conv] = act

    # to avoid quantization steps change among parameter finetuning
    self.quantizer.quant_mode = 2

    net_inputs = [
        list(self.cached_outputs[node.module]) for node in self.input_nodes
    ]

    # Baseline loss before any layer has been tuned.
    last_quant_mods = [node.module for node in last_quant_nodes]
    handlers = self.hook_cache_output(last_quant_mods, hook_type="single")
    net_loss = self.eval_loss(net_inputs, last_quant_mods, device)
    self.clean_hooks(handlers)
    torch.cuda.empty_cache()

    # Pair quantized and float modules by walking both models in lockstep.
    finetune_group = {}
    for qmod, fmod in zip(self._quant_model.modules(),
                          self._float_model.modules()):
        if not hasattr(qmod, "node"):
            continue
        candidate = qmod.node
        if (self.quantizer.configer.is_node_quantizable(candidate, False)
                and len(candidate.op.params) > 0):
            finetune_group[candidate] = [qmod, fmod]

    for node, (_, float_layer) in finetune_group.items():
        parent_module = self.graph.parents(node)[0].module
        handlers = self.hook_cache_output([parent_module], hook_type="single")

        # Collect the first parent's output for each cached input batch.
        layer_inputs = []
        with torch.no_grad():
            for batch in zip(*net_inputs):
                tensors = [
                    item.to(device)
                    for item in batch
                    if isinstance(item, torch.Tensor)
                ]
                self.quant_model(*tensors)
                layer_inputs.append(
                    self.cached_output[parent_module].detach().cpu())
        self.clean_hooks(handlers)
        del self.cached_output[parent_module]

        net_loss = self.optimize_layer(node, float_layer, layer_inputs,
                                       layer_act_group, net_inputs, net_loss,
                                       last_quant_mods, device)
        del layer_inputs
        torch.cuda.empty_cache()

    # recover quantizer status
    for node in self.graph.nodes:
        for history in self.quantizer.fp_history.values():
            if node.name in history:
                history[node.name].clear()
    reset_module_flag("param_quantized")
    reset_module_flag("param_saved")
    self.quantizer.quant_mode = 1
    set_option_value("nndct_param_corr", saved_param_corr)

    # export finetuned parameters
    self.quantizer.export_param()
def finetune(self, run_fn, run_args):
    """Fast-finetune layer parameters, guarded by a host-memory check.

    Steps, in order:
      1. Return with a warning if ``quantizer.quant_mode`` is 2.
      2. Read total/available memory from ``free -t -m`` (values are
         divided by 1024 and reported as G).
      3. Back up ``nndct_param_corr`` and set it to 0.
      4. Call ``run_fn`` with ``nndct_quant_off`` True while caching the
         outputs of INPUT / last quant nodes and monitoring conv-like
         layers. If the estimated cost exceeds 80% of available memory,
         restore both options, warn and return without finetuning.
      5. Call ``run_fn`` again with ``nndct_quant_off`` False.
      6. Set ``quant_mode`` to 2, compute a baseline loss with
         ``eval_loss``, and call ``optimize_layer`` for every quantizable
         node that has parameters.
      7. Clear ``config_history``, reset module flags, set ``quant_mode``
         to 1, restore ``nndct_param_corr`` and call ``export_param``.

    Args:
        run_fn: Callable invoked as ``run_fn(*run_args)``; presumably runs
            the model over the calibration data (confirm with callers).
        run_args: Positional arguments for ``run_fn``.
    """
    if self.quantizer.quant_mode == 2:
        NndctScreenLogger().warning(
            "Finetune function will be ignored in test mode!")
        return

    NndctScreenLogger().info(
        "=>Preparing data for fast finetuning module parameters ...")

    def reset_module_flag(flag):
        # Set ``flag`` to False on every quant-model module that has it.
        for mod in self.quant_model.modules():
            if hasattr(mod, flag):
                setattr(mod, flag, False)

    # memory status: second line of `free -t -m`, first and last numeric
    # columns. The pipe is closed explicitly instead of being leaked.
    with os.popen('free -t -m') as pipe:
        mem_fields = pipe.readlines()[1].split()[1:]
    total_m, *_, available_m = (int(field) / 1024 for field in mem_fields)
    NndctScreenLogger().info(
        f"Mem status(total mem: {total_m:.2f}G, available mem: {available_m:.2f}G)."
    )

    # backup option value
    opt_bak_param_corr = NndctOption.nndct_param_corr.value
    set_option_value("nndct_param_corr", 0)

    # cache input and output
    last_quant_nodes = self.collect_last_quant_nodes()
    with torch.no_grad():
        cache_layers = []
        monitor_layers = []
        for node in self.graph.nodes:
            if node.op.type == NNDCT_OP.INPUT or node in last_quant_nodes:
                cache_layers.append(node.module)
            elif self.quantizer.configer.is_conv_like(node):
                monitor_layers.append(node.module)

        monitor_handlers = self.hook_memory_monitor(monitor_layers)
        cache_handlers = self.hook_cache_output(cache_layers, monitor_mem=True)
        set_option_value("nndct_quant_off", True)
        run_fn(*run_args)

        # memory statistics: every cached layer plus twice the largest
        # remaining (monitored) entry.
        total_memory_cost = 0.0
        for layer in cache_layers:
            total_memory_cost += self._mem_count[layer]
            del self._mem_count[layer]
        # default=0.0 keeps this from raising when nothing was monitored.
        total_memory_cost += 2 * max(self._mem_count.values(), default=0.0)

        self.clean_hooks(monitor_handlers + cache_handlers)
    torch.cuda.empty_cache()

    NndctScreenLogger().info(
        f"Mem cost by fast finetuning: {total_memory_cost:.2f}G.")
    if total_memory_cost > 0.8 * available_m:
        NndctScreenLogger().warning(
            "There is not enought memory for fast finetuning and this process will be ignored!.Try to use a smaller calibration dataset."
        )
        # Leave the options as the normal exit path does, so later runs are
        # not executed with quantization off and param correction at 0.
        set_option_value("nndct_quant_off", False)
        set_option_value("nndct_param_corr", opt_bak_param_corr)
        return

    # calibration to get a set of quantization steps
    reset_module_flag("param_quantized")

    # evaluation to get float model tensors
    set_option_value("nndct_quant_off", False)
    with torch.no_grad():
        run_fn(*run_args)
    torch.cuda.empty_cache()

    NndctScreenLogger().info(
        "=>Fast finetuning module parameters for better quantization accuracy..."
    )
    device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)

    # Every (conv-like, activation) pair, in the original pattern order.
    conv_types = (
        NNDCT_OP.CONV2D,
        NNDCT_OP.DEPTHWISE_CONV2D,
        NNDCT_OP.CONVTRANSPOSE2D,
        NNDCT_OP.CONV3D,
        NNDCT_OP.DEPTHWISE_CONV3D,
        NNDCT_OP.CONVTRANSPOSE3D,
    )
    act_types = (
        NNDCT_OP.HSWISH,
        NNDCT_OP.HSIGMOID,
        NNDCT_OP.RELU,
        NNDCT_OP.RELU6,
    )
    graph_searcher = GraphSearcher(self.graph)
    node_sets = graph_searcher.find_nodes_from_type([
        PatternType(pattern=[conv_type, act_type])
        for conv_type in conv_types
        for act_type in act_types
    ])

    # Map each matched conv-like node to the activation node that follows it.
    layer_act_group = {}
    for _, node_list in node_sets.items():
        for conv, act in node_list:
            layer_act_group[conv] = act

    # to avoid quantization steps change among parameter finetuning
    self.quantizer.quant_mode = 2

    net_inputs = [
        list(self.cached_outputs[node.module]) for node in self.input_nodes
    ]

    # Baseline loss before any layer has been tuned.
    last_quant_mods = [node.module for node in last_quant_nodes]
    handlers = self.hook_cache_output(last_quant_mods, hook_type="single")
    net_loss = self.eval_loss(net_inputs, last_quant_mods, device)
    self.clean_hooks(handlers)
    torch.cuda.empty_cache()

    # Pair quantized and float modules by walking both models in lockstep.
    finetune_group = {}
    for qmod, fmod in zip(self._quant_model.modules(),
                          self._float_model.modules()):
        if hasattr(qmod, "node"):
            if (self.quantizer.configer.is_node_quantizable(qmod.node, False)
                    and len(qmod.node.op.params) > 0):
                finetune_group[qmod.node] = [qmod, fmod]

    for node, (_, float_layer) in tqdm(finetune_group.items(),
                                       total=len(finetune_group)):
        pn_node = self.graph.parents(node)[0]
        handlers = self.hook_cache_output([pn_node.module], hook_type="single")

        # Collect the first parent's output for each cached input batch.
        layer_inputs = []
        with torch.no_grad():
            for input_args in zip(*net_inputs):
                new_input_args = [
                    ip.to(device)
                    for ip in input_args
                    if isinstance(ip, torch.Tensor)
                ]
                _ = self.quant_model(*new_input_args)
                layer_inputs.append(
                    self.cached_output[pn_node.module].detach().cpu())
        self.clean_hooks(handlers)
        del self.cached_output[pn_node.module]

        net_loss = self.optimize_layer(node, float_layer, layer_inputs,
                                       layer_act_group, net_inputs, net_loss,
                                       last_quant_mods, device)
        del layer_inputs
        torch.cuda.empty_cache()

    # recover quantizer status
    for node in self.graph.nodes:
        for _, config_history in self.quantizer.config_history.items():
            if node.name in config_history:
                config_history[node.name].clear()
    reset_module_flag("param_quantized")
    reset_module_flag("param_saved")
    self.quantizer.quant_mode = 1
    set_option_value("nndct_param_corr", opt_bak_param_corr)

    NndctScreenLogger().info("=>Export fast finetuned parameters ...")
    # export finetuned parameters
    self.quantizer.export_param()
def __enter__(self):
    """Save the current ``nndct_param_corr`` value, then set the option to 0.

    The saved value is kept on ``self._param_corr``. Nothing is returned, so
    ``with ... as x`` binds ``x`` to None.
    """
    current_value = NndctOption.nndct_param_corr.value
    self._param_corr = current_value
    set_option_value("nndct_param_corr", 0)
def __exit__(self, *args):
    """Set the ``nndct_quant_off`` option to False on leaving the context.

    The option is always set to False; no earlier value is consulted.
    Returns None, so an exception raised inside the ``with`` body is not
    suppressed.
    """
    set_option_value("nndct_quant_off", False)
    return None
def __enter__(self):
    """Set the ``nndct_quant_off`` option to True on entering the context.

    Nothing is returned, so ``with ... as x`` binds ``x`` to None.
    """
    set_option_value("nndct_quant_off", True)
    return None