Beispiel #1
0
 def _prepare_deployable_graph(self, module, input_args, device,
                               output_dir):
     """Run `module` once and return the optimized device graph.

     The module and its inputs are moved to `device`, converted with
     `prepare_quantizable_module` (using `output_dir` as export folder),
     executed once with output recording enabled, and the resulting graph
     is passed through the `DevGraphOptimizer` passes.

     Args:
         module: model to convert.
         input_args: a single input, or a tuple of positional inputs.
         device: target device handed to `to_device` and
             `prepare_quantizable_module`.
         output_dir: export folder for `prepare_quantizable_module`.

     Returns:
         The `dev_graph` of the `DevGraphOptimizer` after all passes.

     Note:
         The "nndct_quant_off" option is set to True here and is not
         reset by this method.
     """
     module, input_args = to_device(module, input_args, device)
     quant_module, graph = prepare_quantizable_module(
         module=module,
         input_args=input_args,
         export_folder=output_dir,
         device=device)

     set_option_value("nndct_quant_off", True)
     register_output_hook(quant_module, record_once=True)
     set_outputs_recorder_status(quant_module, True)

     # A single forward pass; a non-tuple input is passed as one argument.
     forward_args = input_args if isinstance(input_args,
                                             tuple) else (input_args,)
     quant_module(*forward_args)

     optimizer = DevGraphOptimizer(graph)

     def _sync_with_dev_graph():
         # Bind the module to the current device graph and refresh blobs.
         connect_module_with_graph(quant_module,
                                   optimizer.dev_graph,
                                   recover_param=False)
         update_nndct_blob_data(quant_module, optimizer.dev_graph)

     _sync_with_dev_graph()
     # Passes run strictly in this order.
     for pass_name in ("strip_redundant_ops", "update_op_attrs",
                       "constant_folding", "layout_tranform"):
         getattr(optimizer, pass_name)()
     _sync_with_dev_graph()

     # Finally re-attach the module to the original (unoptimized) graph.
     connect_module_with_graph(quant_module, graph, recover_param=False)
     return optimizer.dev_graph
Beispiel #2
0
 def __exit__(self, *args):
   """Restore `nndct_param_corr` and reset the processor state.

   Puts back the option value stored in `self._param_corr`, clears the
   per-node entries of every quantizer `fp_history`, resets the
   `param_quantized` / `param_saved` flags on the quantized model's
   modules, and then calls `setup_calib()` and `set_keep_fp(True)` on the
   processor.

   Returns None, so an exception raised inside the `with` body is not
   suppressed.
   """
   set_option_value("nndct_param_corr", self._param_corr)
   processor = self._processor

   # Empty the recorded history of each graph node.
   for node in processor.graph.nodes:
     for history in processor.quantizer.fp_history.values():
       if node.name in history:
         history[node.name].clear()

   # Reset the flags one after the other, each over all modules.
   for flag_name in ("param_quantized", "param_saved"):
     for mod in processor.quant_model.modules():
       if hasattr(mod, flag_name):
         setattr(mod, flag_name, False)

   processor.setup_calib()
   # don't change tensors' quantization step in re-calibration after fast finetune
   processor.set_keep_fp(True)
Beispiel #3
0
    def finetune(self, run_fn, run_args):
        """Finetune the parameters of quantizable layers one layer at a time.

        Steps, in order:
          1. With "nndct_quant_off" set to True, run `run_fn(*run_args)`
             while caching the outputs of the input nodes and of the last
             quantized nodes.
          2. With "nndct_quant_off" set to False, run `run_fn` again.
          3. Switch the quantizer to mode 2, evaluate the loss, and call
             `optimize_layer` for every quantizable node that has
             parameters.
          4. Clear the quantizer history and module flags, set the
             quantizer mode to 1, restore "nndct_param_corr" and export
             the parameters.

        Does nothing except log a warning when the quantizer is already in
        mode 2.

        Args:
            run_fn: callable that runs the model; called as
                `run_fn(*run_args)`.
            run_args: positional arguments for `run_fn`.
        """
        if self.quantizer.quant_mode == 2:
            NndctScreenLogger().warning(
                "Finetune function will be ignored in test mode!")
            return

        NndctScreenLogger().info(
            "=>Finetuning module parameters for better quantization accuracy... "
        )

        def _clear_module_flag(flag_name):
            # Set `flag_name` to False on each module that defines it.
            for mod in self.quant_model.modules():
                if hasattr(mod, flag_name):
                    setattr(mod, flag_name, False)

        # Keep the current option value; it is put back before exporting.
        saved_param_corr = NndctOption.nndct_param_corr.value
        set_option_value("nndct_param_corr", 0)

        # Pass 1: quantization off, cache network inputs and final outputs.
        last_quant_nodes = self.collect_last_quant_nodes()
        with torch.no_grad():
            cached_modules = [
                node.module for node in self.graph.nodes
                if node.op.type == NNDCT_OP.INPUT or node in last_quant_nodes
            ]
            cache_handles = self.hook_cache_output(cached_modules)

            set_option_value("nndct_quant_off", True)
            run_fn(*run_args)
            self.clean_hooks(cache_handles)
        torch.cuda.empty_cache()

        # Pass 2: quantization on again.
        _clear_module_flag("param_quantized")
        set_option_value("nndct_quant_off", False)
        with torch.no_grad():
            run_fn(*run_args)
        torch.cuda.empty_cache()

        device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)
        graph_searcher = GraphSearcher(self.graph)
        conv_act_pairs = (
            (NNDCT_OP.CONV2D, NNDCT_OP.RELU),
            (NNDCT_OP.CONV2D, NNDCT_OP.RELU6),
            (NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU),
            (NNDCT_OP.DEPTHWISE_CONV2D, NNDCT_OP.RELU6),
            (NNDCT_OP.CONVTRANSPOSE2D, NNDCT_OP.RELU),
        )
        node_sets = graph_searcher.find_nodes_from_type([
            PatternType(pattern=[conv_type, act_type])
            for conv_type, act_type in conv_act_pairs
        ])

        # Map each matched conv node to the activation node following it.
        layer_act_group = {
            conv: act
            for node_list in node_sets.values()
            for conv, act in node_list
        }

        # to avoid quantization steps change among parameter finetuning
        self.quantizer.quant_mode = 2

        net_inputs = [
            list(self.cached_outputs[node.module]) for node in self.input_nodes
        ]
        last_quant_mods = [node.module for node in last_quant_nodes]

        # Baseline loss before any layer has been tuned.
        loss_handles = self.hook_cache_output(last_quant_mods,
                                              hook_type="single")
        net_loss = self.eval_loss(net_inputs, last_quant_mods, device)
        self.clean_hooks(loss_handles)
        torch.cuda.empty_cache()

        # Pair quantized and float modules for every tunable node.
        finetune_group = {}
        module_pairs = zip(self._quant_model.modules(),
                           self._float_model.modules())
        for qmod, fmod in module_pairs:
            if not hasattr(qmod, "node"):
                continue
            is_tunable = (self.quantizer.configer.is_node_quantizable(
                qmod.node, False) and len(qmod.node.op.params) > 0)
            if is_tunable:
                finetune_group[qmod.node] = [qmod, fmod]

        for node, (_, float_layer) in finetune_group.items():
            # The layer's input is the output of its first parent node.
            parent = self.graph.parents(node)[0]
            parent_handles = self.hook_cache_output([parent.module],
                                                    hook_type="single")
            layer_inputs = []
            with torch.no_grad():
                for batch in zip(*net_inputs):
                    # Non-tensor entries are skipped, tensors go to `device`.
                    tensors = [
                        ip.to(device) for ip in batch
                        if isinstance(ip, torch.Tensor)
                    ]
                    self.quant_model(*tensors)
                    layer_inputs.append(
                        self.cached_output[parent.module].detach().cpu())
            self.clean_hooks(parent_handles)
            del self.cached_output[parent.module]

            net_loss = self.optimize_layer(node, float_layer, layer_inputs,
                                           layer_act_group, net_inputs,
                                           net_loss, last_quant_mods, device)
            del layer_inputs
            torch.cuda.empty_cache()

        # recover quantizer status
        for node in self.graph.nodes:
            for history in self.quantizer.fp_history.values():
                if node.name in history:
                    history[node.name].clear()
        _clear_module_flag("param_quantized")
        _clear_module_flag("param_saved")
        self.quantizer.quant_mode = 1
        set_option_value("nndct_param_corr", saved_param_corr)

        # export finetuned parameters
        self.quantizer.export_param()
Beispiel #4
0
    def finetune(self, run_fn, run_args):
        """Fast-finetune the parameters of quantizable layers, layer by layer.

        Steps, in order:
          1. Read host memory figures from `free -t -m`.
          2. With "nndct_quant_off" set to True, run `run_fn(*run_args)`
             while caching the outputs of the input / last quantized nodes
             and recording memory counts for conv-like layers.
          3. If the estimated memory cost exceeds 80% of the available
             memory, log a warning, undo the option changes and return.
          4. With "nndct_quant_off" set to False, run `run_fn` again.
          5. Switch the quantizer to mode 2, evaluate the loss, and call
             `optimize_layer` for every quantizable node that has
             parameters.
          6. Clear the quantizer history and module flags, set the
             quantizer mode to 1, restore "nndct_param_corr" and export
             the parameters.

        Does nothing except log a warning when the quantizer is already in
        mode 2.

        Args:
            run_fn: callable that runs the model; called as
                `run_fn(*run_args)`.
            run_args: positional arguments for `run_fn`.

        Raises:
            IndexError: if `free -t -m` yields fewer than two lines
                (e.g. the command is unavailable on this platform).
        """
        if self.quantizer.quant_mode == 2:
            NndctScreenLogger().warning(
                "Finetune function will be ignored in test mode!")
            return
        NndctScreenLogger().info(
            "=>Preparing data for fast finetuning module parameters ...")

        # memory status: second line of `free`, label column dropped; the
        # first and last remaining fields are used as total / available.
        # The `with` block makes sure the pipe is closed.
        with os.popen('free -t -m') as mem_report:
            mem_fields = mem_report.readlines()[1].split()[1:]
        total_m, *_, available_m = [int(field) / 1024 for field in mem_fields]
        NndctScreenLogger().info(
            f"Mem status(total mem: {total_m:.2f}G, available mem: {available_m:.2f}G)."
        )

        # backup option value
        opt_bak_param_corr = NndctOption.nndct_param_corr.value
        set_option_value("nndct_param_corr", 0)

        # cache input and output
        last_quant_nodes = self.collect_last_quant_nodes()
        with torch.no_grad():
            cache_layers = []
            monitor_layers = []
            for node in self.graph.nodes:
                if node.op.type == NNDCT_OP.INPUT or node in last_quant_nodes:
                    cache_layers.append(node.module)
                elif self.quantizer.configer.is_conv_like(node):
                    monitor_layers.append(node.module)

            monitor_handlers = self.hook_memory_monitor(monitor_layers)
            cache_handlers = self.hook_cache_output(cache_layers,
                                                    monitor_mem=True)
            set_option_value("nndct_quant_off", True)
            run_fn(*run_args)

            # memory statistics: everything cached, plus twice the largest
            # remaining (monitored) entry.
            total_memory_cost = 0.0
            for layer in cache_layers:
                total_memory_cost += self._mem_count[layer]
                del self._mem_count[layer]
            # `default` covers the case where no layer was monitored.
            total_memory_cost += 2 * max(self._mem_count.values(),
                                         default=0.0)

            self.clean_hooks(monitor_handlers + cache_handlers)

        torch.cuda.empty_cache()
        NndctScreenLogger().info(
            f"Mem cost by fast finetuning: {total_memory_cost:.2f}G.")
        if total_memory_cost > 0.8 * available_m:
            NndctScreenLogger().warning(
                "There is not enought memory for fast finetuning and this process will be ignored!.Try to use a smaller calibration dataset."
            )
            # Leave the options in the same state as a completed run:
            # quantization switched back on and param_corr restored.
            set_option_value("nndct_quant_off", False)
            set_option_value("nndct_param_corr", opt_bak_param_corr)
            return

        # calibration to get a set of quantization steps
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_quantized"):
                setattr(mod, "param_quantized", False)

        # evaluation to get float model tensors
        set_option_value("nndct_quant_off", False)
        with torch.no_grad():
            run_fn(*run_args)
        torch.cuda.empty_cache()

        NndctScreenLogger().info(
            "=>Fast finetuning module parameters for better quantization accuracy..."
        )
        device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)
        graph_searcher = GraphSearcher(self.graph)

        # Every conv-like op type paired with every supported activation,
        # conv type varying slowest.
        conv_types = (
            NNDCT_OP.CONV2D,
            NNDCT_OP.DEPTHWISE_CONV2D,
            NNDCT_OP.CONVTRANSPOSE2D,
            NNDCT_OP.CONV3D,
            NNDCT_OP.DEPTHWISE_CONV3D,
            NNDCT_OP.CONVTRANSPOSE3D,
        )
        act_types = (
            NNDCT_OP.HSWISH,
            NNDCT_OP.HSIGMOID,
            NNDCT_OP.RELU,
            NNDCT_OP.RELU6,
        )
        node_sets = graph_searcher.find_nodes_from_type([
            PatternType(pattern=[conv_type, act_type])
            for conv_type in conv_types
            for act_type in act_types
        ])

        # Map each matched conv node to the activation node following it.
        layer_act_group = {}
        for node_list in node_sets.values():
            for conv, act in node_list:
                layer_act_group[conv] = act

        # to avoid quantization steps change among parameter finetuning
        self.quantizer.quant_mode = 2

        net_inputs = [
            list(self.cached_outputs[node.module]) for node in self.input_nodes
        ]
        last_quant_mods = [node.module for node in last_quant_nodes]

        # Baseline loss before any layer has been tuned.
        handlers = self.hook_cache_output(last_quant_mods, hook_type="single")
        net_loss = self.eval_loss(net_inputs, last_quant_mods, device)
        self.clean_hooks(handlers)
        torch.cuda.empty_cache()

        # Pair quantized and float modules for every tunable node.
        finetune_group = {}
        for qmod, fmod in zip(self._quant_model.modules(),
                              self._float_model.modules()):
            if hasattr(qmod, "node"):
                if (self.quantizer.configer.is_node_quantizable(
                        qmod.node, False) and len(qmod.node.op.params) > 0):
                    finetune_group[qmod.node] = [qmod, fmod]

        for node, (_, float_layer) in tqdm(finetune_group.items(),
                                           total=len(finetune_group)):
            # The layer's input is the output of its first parent node.
            pn_node = self.graph.parents(node)[0]
            handlers = self.hook_cache_output([pn_node.module],
                                              hook_type="single")
            layer_inputs = []
            with torch.no_grad():
                for input_args in zip(*net_inputs):
                    # Non-tensor entries are skipped, tensors go to `device`.
                    new_input_args = [
                        ip.to(device) for ip in input_args
                        if isinstance(ip, torch.Tensor)
                    ]
                    self.quant_model(*new_input_args)

                    layer_inputs.append(
                        self.cached_output[pn_node.module].detach().cpu())
            self.clean_hooks(handlers)
            del self.cached_output[pn_node.module]

            net_loss = self.optimize_layer(node, float_layer, layer_inputs,
                                           layer_act_group, net_inputs,
                                           net_loss, last_quant_mods, device)
            del layer_inputs
            torch.cuda.empty_cache()

        # recover quantizer status
        for node in self.graph.nodes:
            for config_history in self.quantizer.config_history.values():
                if node.name in config_history:
                    config_history[node.name].clear()
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_quantized"):
                setattr(mod, "param_quantized", False)
        for mod in self.quant_model.modules():
            if hasattr(mod, "param_saved"):
                setattr(mod, "param_saved", False)
        self.quantizer.quant_mode = 1
        set_option_value("nndct_param_corr", opt_bak_param_corr)

        NndctScreenLogger().info("=>Export fast finetuned parameters ...")
        # export finetuned parameters
        self.quantizer.export_param()
Beispiel #5
0
 def __enter__(self):
     """Save the current "nndct_param_corr" option value and set it to 0.

     The previous value is stored on `self._param_corr`; presumably the
     matching `__exit__` writes it back (not visible in this block --
     confirm). Returns None, so `with ... as x` binds `x` to None.
     """
     # Remember the value that was active before entering the block.
     self._param_corr = NndctOption.nndct_param_corr.value
     set_option_value("nndct_param_corr", 0)
Beispiel #6
0
 def __exit__(self, *args):
     """Set the "nndct_quant_off" option to False on leaving the block.

     The exception details passed in `args` are ignored and None is
     returned, so an exception raised inside the `with` body propagates.
     The option is always set to False, whatever its value was before
     the block was entered.
     """
     set_option_value("nndct_quant_off", False)
Beispiel #7
0
 def __enter__(self):
     """Set the "nndct_quant_off" option to True on entering the block.

     Presumably this switches quantization off for the body of the `with`
     statement (inferred from the option name -- confirm). The previous
     option value is not saved. Returns None.
     """
     set_option_value("nndct_quant_off", True)