Example #1
def _do_map(output_name, node_name):
    if output_name != node_name:
        if not GLOBAL_MAP.get_ele(NNDCT_KEYS.OUTPUT_TO_NODE_MAP):
            GLOBAL_MAP.set_map(NNDCT_KEYS.OUTPUT_TO_NODE_MAP, {})
        if not GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_TO_OUTPUT_MAP):
            GLOBAL_MAP.set_map(NNDCT_KEYS.NODE_TO_OUTPUT_MAP, {})
        # map output to node
        output_to_node_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.OUTPUT_TO_NODE_MAP)
        if output_name not in output_to_node_map:
            nndct_debug_print(
                "<map_output_and_node> map out {} and node {}".format(
                    output_name, node_name),
                level=NNDCT_DEBUG_LVL.BUILD_GRAPH)
            output_to_node_map[output_name] = node_name
        else:
            assert output_to_node_map[output_name] == node_name, (
                "stored node name for output_name {} is {}, but met new node name {}".format(
                    output_name, output_to_node_map[output_name], node_name))
        # add the output to a list keyed by node_name
        node_to_output_map = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_TO_OUTPUT_MAP)
        if node_name not in node_to_output_map:
            node_to_output_map[node_name] = [output_name]
        else:
            node_to_output_map[node_name].append(output_name)
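These examples all share one pattern: GLOBAL_MAP.set_map(key, value) registers a value under a key from NNDCT_KEYS, and GLOBAL_MAP.get_ele(key) reads it back, returning a falsy result when nothing has been registered yet. The snippet below is only a minimal stand-in sketch of that registry, written so the mapping logic of Example #1 can be exercised in isolation; the GlobalMap and Keys names are hypothetical and this is not the nndct implementation.

# Hypothetical stand-in for the GLOBAL_MAP registry used in these examples.
class GlobalMap:
    def __init__(self):
        self._store = {}

    def set_map(self, key, value):
        # Register `value` under `key`.
        self._store[key] = value

    def get_ele(self, key):
        # Return the registered value, or None if the key is unset.
        return self._store.get(key)


class Keys:
    # Hypothetical key constants standing in for NNDCT_KEYS.
    OUTPUT_TO_NODE_MAP = "output_to_node_map"
    NODE_TO_OUTPUT_MAP = "node_to_output_map"


registry = GlobalMap()
if not registry.get_ele(Keys.OUTPUT_TO_NODE_MAP):
    registry.set_map(Keys.OUTPUT_TO_NODE_MAP, {})
registry.get_ele(Keys.OUTPUT_TO_NODE_MAP)["conv1.out.0"] = "conv1"
print(registry.get_ele(Keys.OUTPUT_TO_NODE_MAP))  # {'conv1.out.0': 'conv1'}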
Example #2
 def custom_op(self, node, *args):
     node2caller = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
     if node2caller is None:
         node2caller: Dict[str, Callable] = {}
         GLOBAL_MAP.set_map(NNDCT_KEYS.NODE_CALLER_MAP, node2caller)
     node2caller[node.name] = node.caller
     op = TorchCustomOperation(node.raw_kind, node.raw_kind)
     for i, arg in enumerate(args):
         op.set_config(str(i), arg)
     attrs = GLOBAL_MAP.get_ele(NNDCT_KEYS.CUSTOM_OP_ATTRS_MAP).get(
         node.raw_kind, None)
     if attrs:
         attr_vals = args[len(args) - len(attrs):]
         for name, val in zip(attrs, attr_vals):
             op.set_attr_by_name(name, val)
     return op
Example #3
    def _init_quant_env():
        nonlocal quant_mode
        if NndctOption.nndct_quant_mode.value > 0:
            quant_mode = NndctOption.nndct_quant_mode.value

        if quant_mode == 1:
            NndctScreenLogger().info(
                "Quantization calibration process start up...")
        elif quant_mode == 2:
            NndctScreenLogger().info("Quantization test process start up...")

        quantizer = TORCHQuantizer(quant_mode, output_dir, bitwidth_w,
                                   bitwidth_a)
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANTIZER, quantizer)
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_MODE, quant_mode)
        return quantizer, quant_mode
Example #4
def build_aten_torch_ops_table():
  op_gathering_fns = (_get_tensor_ops, 
                      _get_nn_functional_ops, 
                      _get_torchscript_builtins, 
                      _get_global_builtins, 
                      _get_math_builtins,
                      )
  schema2torchop = GLOBAL_MAP.get_ele(NNDCT_KEYS.TORCH_SCHEMA_OP_TABLE)
  # schema_lut = GLOBAL_MAP.get_ele(NNDCT_KEYS.SCHEMA_LUT)
  if not schema2torchop:
    schema2torchop: Dict[str, TorchOp] = {}
    GLOBAL_MAP.set_map(NNDCT_KEYS.TORCH_SCHEMA_OP_TABLE, schema2torchop)

    # schema_lut: Dict[Tuple(str, int), "Schema"] = {}
    for fn in op_gathering_fns:
      fn()
Example #5
    def convert_to_deployable(self, trained_model, mix_bit=False):
        if not self._qinfo_to_quantizer or not self._module_map:
            raise RuntimeError('Must call "trainable_model" first.')

        # Copy trained parameters from transformed model to original float model.
        orig_state_dict = self._model.state_dict()
        trained_state_dict = trained_model.state_dict()
        state_dict = {}
        for key in orig_state_dict.keys():
            module_name, weight_name = key.rsplit('.', 1)
            if module_name in self._module_map:
                trained_module_name = self._module_map[module_name]
                trained_key = '.'.join([trained_module_name, weight_name])
            else:
                trained_key = key
            state_dict[key] = trained_state_dict[trained_key]
        model = copy.deepcopy(self._model)
        model.load_state_dict(state_dict)
        model.eval()
        '''
        inputs = dummy_inputs(self._input_specs)
        qprocessor = qproc.TorchQuantProcessor(
            'test',
            model,
            [inp.cuda() for inp in inputs],
            mix_bit=mix_bit,
            device=torch.device('cuda'))
        '''
        inputs = self._input_args
        qprocessor = qproc.TorchQuantProcessor('test',
                                               model,
                                               inputs,
                                               mix_bit=mix_bit,
                                               device=torch.device('cuda'))

        quantizer = qprocessor.quantizer
        self._fill_in_quant_info(quantizer, self._qinfo_to_quantizer)
        quantizer.export_quant_config()

        quant_model = quantizer.quant_model
        quant_model.dump_xmodel = dump_xmodel
        self.deploy_quantizer = quantizer
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANTIZER, quantizer)
        NndctScreenLogger().info("=>Deployable model is generated.")
Example #6
def tf_quantizer(model,
                 input_signature,
                 quant_mode: str = "calib",
                 output_dir: str = "quantize_result",
                 bitwidth: int = 8):
    # initialize quant mode
    qmode = _init_quant_mode(quant_mode)

    # turn off weights equalization and bias correction
    option_util.set_option_value("nndct_param_corr", False)
    option_util.set_option_value("nndct_equalization", False)

    # the LSTM IP only supports 16-bit activations
    quantizer = TFQuantizer(qmode, output_dir, bitwidth, 16)
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANTIZER, quantizer)
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_MODE, qmode)

    graph = parser.from_keras_model(model, input_signature)
    quant_model, layer_nodes = builder.KerasBuilder(graph).build(
        os.path.join(output_dir, model.name + '_quant.py'), quantized=True)

    rebuilding_results = _maybe_rebuild_rnn(quant_model)
    if rebuilding_results:
        cell_graphs = []
        cell_layer_nodes = []
        for graph, layer_nodes in rebuilding_results:
            cell_graphs.append(graph)
            cell_layer_nodes.extend(layer_nodes)
            quantizer.add_rnn_cell_graph('forward', graph)

        graph = _merge_cell_graphs(cell_graphs)
        layer_nodes = cell_layer_nodes
        # TODO(yuwang): Support backward direction.

    export_file = os.path.join(output_dir, 'merged_graph.pb')
    graph_utils.maybe_export_graph(export_file, graph)

    lstm = len(rebuilding_results) > 0
    quantizer.setup(graph, lstm=lstm)
    quantizer.load_node_to_layer(layer_nodes, quant_model)

    return quantizer
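A minimal usage sketch for tf_quantizer, assuming TensorFlow is installed and that the function above is importable from the library's TensorFlow front end; the toy Keras model and input signature are illustrative only.

import tensorflow as tf

# Toy recurrent model; tf_quantizer forces 16-bit activations for the LSTM IP.
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(32, return_sequences=True, input_shape=(10, 16)),
    tf.keras.layers.Dense(4),
])
input_signature = tf.TensorSpec(shape=(None, 10, 16), dtype=tf.float32)

quantizer = tf_quantizer(model,
                         input_signature,
                         quant_mode="calib",
                         output_dir="quantize_result",
                         bitwidth=8)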
Example #7
    def default(self, node, *args):
        schema2torchop = GLOBAL_MAP.get_ele(NNDCT_KEYS.TORCH_SCHEMA_OP_TABLE)
        schema_handler = SchemaHelper(node.schema)
        torchop = schema2torchop.get(schema_handler.toString(), None)
        if torchop is None:
            op = TorchUnknownOperation(node.raw_kind)
            return op
        node2caller = GLOBAL_MAP.get_ele(NNDCT_KEYS.NODE_CALLER_MAP)
        if node2caller is None:
            node2caller: Dict[str, Callable] = {}
            GLOBAL_MAP.set_map(NNDCT_KEYS.NODE_CALLER_MAP, node2caller)
        node2caller[node.name] = torchop.caller
        op = TorchBaseOperation(schema_handler.op_name,
                                torchop.name,
                                schema=node.schema)
        # op.set_caller(torchop.caller)
        assert len(args) == len(schema_handler.get_arguments())
        if len(args) == 1:
            return op
        arg_name_convertor = {"self": "input"}
        for inp, arg in zip(args, schema_handler.get_arguments()):
            arg_name = schema_handler.arg_name(arg)
            if torchop.op_class_type == TorchOpClassType.TENSOR and arg_name == "self":
                continue
            if arg_name in ["layout", "memory_format", "pin_memory"]:
                continue
            config_name = arg_name_convertor.get(arg_name, arg_name)
            if convert_type_str(schema_handler.arg_type(arg)).replace(
                    "?", "") == "bool":
                inp = bool(inp) if inp is not None else inp
            if convert_type_str(schema_handler.arg_type(arg)).replace(
                    "?", "") == "str":
                inp = f"'{inp}'" if inp is not None else inp

            if arg_name == "device":
                inp = f"'{self._device_type}'"
            if arg_name == "dtype":
                inp = scalar_type_to_pytorch_type[
                    inp] if inp is not None else inp
            op.set_config(config_name, inp)
        return op
Example #8
def export_traced_torch_script(self, output_dir, verbose=False):
    torch_version = torch.__version__.split('.')
    if int(torch_version[0]) == 1 and int(torch_version[1]) < 7:
        NndctScreenLogger().error(
            'Exporting torch script is only supported with pytorch 1.7 and later versions'
        )
        return
    self.quantizer.reset_status_for_exporting()
    device = GLOBAL_MAP.get_ele(NNDCT_KEYS.QUANT_DEVICE)
    force_cpu = os.getenv('NNDCT_FORCE_CPU_DUMP')
    if force_cpu is not None:
        device = torch.device('cpu')
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_DEVICE, device)
    model, input_args = to_device(self.quantizer.quant_model,
                                  self._example_inputs, device)
    script_module = torch.jit.trace(model, input_args, check_trace=False)
    output_file = os.path.join(
        output_dir, f"{self.quantizer.quant_model._get_name()}_int.pt")
    if verbose:
        print(script_module.inlined_graph)
    torch.jit.save(script_module, output_file)
Example #9
    def __init__(self,
                 quant_mode: str,
                 module: torch.nn.Module,
                 input_args: Union[torch.Tensor, Sequence[Any]] = None,
                 state_dict_file: Optional[str] = None,
                 output_dir: str = "quantize_result",
                 bitwidth_w: int = 8,
                 bitwidth_a: int = 8,
                 mix_bit: bool = False,
                 device: torch.device = torch.device("cuda"),
                 lstm_app: bool = False):
        # Check arguments type
        self._check_args(module, input_args)

        # Check device available
        if device.type == "cuda":
            if not (torch.cuda.is_available() and "CUDA_HOME" in os.environ):
                device = torch.device("cpu")
                NndctScreenLogger().warning(
                    "CUDA is not available, changing device to CPU")

        # Transform torch module to quantized module format
        nndct_utils.create_work_dir(output_dir)

        # Create a quantizer object, which controls the whole quantization flow.
        quant_strategy = DefaultQstrategy(bits_weight=bitwidth_w,
                                          bits_bias=bitwidth_a,
                                          bits_activation=bitwidth_a,
                                          mix_bit=mix_bit)
        quantizer, qmode = self._init_quant_env(quant_mode, output_dir,
                                                quant_strategy)
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANTIZER, quantizer)
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_MODE, qmode)
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_DEVICE, device)
        option_util.set_option_value("nndct_cv_app", not lstm_app)

        # Prepare quantizable module
        quant_module, graph = prepare_quantizable_module(
            module=module,
            input_args=input_args,
            export_folder=output_dir,
            state_dict_file=state_dict_file,
            quant_mode=qmode,
            device=device)

        # enable recording of per-layer outputs
        if qmode > 1:
            register_output_hook(quant_module, record_once=True)
            set_outputs_recorder_status(quant_module, True)

        # initialize quantizer
        quantizer.setup(graph, False, lstm_app)

        # hook module with quantizer
        # connect_module_with_quantizer(quant_module, quantizer)
        quantizer.quant_model = quant_module

        self.quantizer = quantizer
        self.adaquant = None
Example #10
  def __init__(self,
               quant_mode: str,
               module: torch.nn.Module,
               input_args: Union[torch.Tensor, Sequence[Any]] = None,
               state_dict_file: Optional[str] = None,
               output_dir: str = "quantize_result",
               bitwidth_w: int = 8,
               bitwidth_a: int = 8,
               device: torch.device = torch.device("cuda"),
               lstm_app: bool = True):
    self._export_folder = output_dir
    # Check arguments type
    self._check_args(module)
    
    # Check device available
    if device.type == "cuda":
      if not (torch.cuda.is_available() and "CUDA_HOME" in os.environ):
        device = torch.device("cpu")
        NndctScreenLogger().warning("CUDA is not available, changing device to CPU")
    
    # Transform torch module to quantized module format
    nndct_utils.create_work_dir(output_dir)
    
    # turn off weights equalization and bias correction
    option_util.set_option_value("nndct_quant_opt", 0)
    option_util.set_option_value("nndct_param_corr", False)
    option_util.set_option_value("nndct_equalization", False)
    
    # Create a quantizer object, which controls the whole quantization flow.
    #if quant_strategy == None:
    quant_strategy = DefaultQstrategy(bits_weight=bitwidth_w,
                                      bits_bias=bitwidth_w,
                                      bits_activation=bitwidth_a)
    quantizer, qmode = self._init_quant_env(quant_mode, 
                                            output_dir,
                                            quant_strategy)
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANTIZER, quantizer)
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_MODE, qmode)
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_DEVICE, device)
    
    standard_RNNs, customized_RNNs = self._analyse_module(module)

    if len(standard_RNNs) == 0 and len(customized_RNNs) == 0:
      raise RuntimeError(
          f"The top module '{module._get_name()}' should have at least one LSTM module."
      )

    self._modules_info = defaultdict(dict)

    # process customized Lstm
    for layer_name, layer_module in customized_RNNs.items():
      for cell_name, cell_module in layer_module.named_children():
        lstm_direction = "forward" if layer_module.go_forward else "backward"
        full_cell_name = ".".join([layer_name, cell_name])
        layer_graph = self._get_customized_LSTM_graph(full_cell_name,
                                                      cell_module,
                                                      layer_module.input_size,
                                                      layer_module.hidden_size,
                                                      layer_module.memory_size)
        self._modules_info[full_cell_name]["layers_graph"] = [{
            lstm_direction: layer_graph
        }]
        self._modules_info[full_cell_name]["stack_mode"] = None
        self._modules_info[full_cell_name]["layer_module"] = layer_module

    # process standard Lstm
    for name, rnn_module in standard_RNNs.items():
      layers_graph = self._get_standard_RNN_graph(
          graph_name=name, lstm_module=rnn_module)
      self._modules_info[name]["layers_graph"] = layers_graph
      self._modules_info[name]["input_size"] = [rnn_module.input_size
                                                ] * rnn_module.num_layers
      self._modules_info[name]["hidden_size"] = [rnn_module.hidden_size
                                                 ] * rnn_module.num_layers
      self._modules_info[name]["memory_size"] = [rnn_module.hidden_size
                                                 ] * rnn_module.num_layers
      self._modules_info[name][
          "stack_mode"] = "bidirectional" if rnn_module.bidirectional else "unidirectional"
      self._modules_info[name]["batch_first"] = rnn_module.batch_first

      if rnn_module.mode == 'LSTM':
        self._modules_info[name]["mode"] = "LSTM"
      elif rnn_module.mode == "GRU": 
        self._modules_info[name]["mode"] = "GRU"
    # merge multi graphs into a graph
    top_graph = self._merge_subgraphs()
    
    # turn on quantizer
    #if quant_mode:
    quantizer.setup(top_graph, rnn_front_end=True, lstm=True)
    
    # write and reload quantizable cell module
    module_graph_map = self._rebuild_layer_module()
    
    # replace the float module with the quantizable module
    for name, info in self._modules_info.items():
      if info["stack_mode"] is not None:
        self._build_stack_lstm_module(info)
      else:
        info["QLSTM"] = list(info["layers_module"][0].values())[0]
      module = self._insert_QuantLstm_in_top_module(module, name, info)

    # move modules info into layers info
    self._convert_modules_info_to_layers(module_graph_map)

    # hook module with quantizer
    # connect_module_with_quantizer(quant_module, quantizer)
    quantizer.quant_model = module

    self.quantizer = quantizer
Example #11
  def quantize_modules(self, top_module: torch.nn.Module) -> torch.nn.Module:
    """
    `prepare quantizable LSTM sub modules.`
    
    Args:
        top_module (torch.nn.Module): Top Module in which LSTM need to do quantization
    
    Raises:
        RuntimeError: The top module should have one LSTM at least.
    
    Returns:
        torch.nn.Module: Top Module in which LSTM sub modules are transformed to quantizible module
    """

    standard_RNNs, customized_RNNs = self._analyse_module(top_module)

    if len(standard_RNNs) == 0 and len(customized_RNNs) == 0:
      raise RuntimeError(
          f"The top module '{top_module._get_name()}' should have at least one LSTM module."
      )

    nndct_utils.create_work_dir(self._export_folder)

    self._modules_info = defaultdict(dict)

    # process customized Lstm
    for layer_name, layer_module in customized_RNNs.items():
      for cell_name, cell_module in layer_module.named_children():
        lstm_direction = "forward" if layer_module.go_forward else "backward"
        full_cell_name = ".".join([layer_name, cell_name])
        layer_graph = self._get_customized_LSTM_graph(full_cell_name,
                                                      cell_module,
                                                      layer_module.input_size,
                                                      layer_module.hidden_size,
                                                      layer_module.memory_size)
        self._modules_info[full_cell_name]["layers_graph"] = [{
            lstm_direction: layer_graph
        }]
        self._modules_info[full_cell_name]["stack_mode"] = None
        self._modules_info[full_cell_name]["layer_module"] = layer_module

    # process standard Lstm
    for name, module in standard_RNNs.items():
      layers_graph = self._get_standard_RNN_graph(
          graph_name=name, lstm_module=module)
      self._modules_info[name]["layers_graph"] = layers_graph
      self._modules_info[name]["input_size"] = [module.input_size
                                                ] * module.num_layers
      self._modules_info[name]["hidden_size"] = [module.hidden_size
                                                 ] * module.num_layers
      self._modules_info[name]["memory_size"] = [module.hidden_size
                                                 ] * module.num_layers
      self._modules_info[name][
          "stack_mode"] = "bidirectional" if module.bidirectional else "unidirectional"
      self._modules_info[name]["batch_first"] = module.batch_first

      if module.mode == 'LSTM':
        self._modules_info[name]["mode"] = "LSTM"
      elif module.mode == "GRU": 
        self._modules_info[name]["mode"] = "GRU"
    # merge multi graphs into a graph
    top_graph = self._merge_subgraphs()
    
    # turn on quantizer
    if self._quant_mode:
      quantizer = TORCHQuantizer(self._quant_mode, self._export_folder,
                                 self._bit_w, self._bit_a)
      GLOBAL_MAP.set_map(NNDCT_KEYS.QUANTIZER, quantizer)
      GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_MODE, self._quant_mode)
      quantizer.setup(top_graph, lstm=True)
    
    # write and reload quantizable cell module
    module_graph_map = self._rebuild_layer_module()
    
    # hook quantizer and module
    if self._quant_mode:
      self._hook_quant_module_with_quantizer(quantizer)
    
    # replace the float module with the quantizable module
    for name, info in self._modules_info.items():
      if info["stack_mode"] is not None:
        self._build_stack_lstm_module(info)
      else:
        info["QLSTM"] = list(info["layers_module"][0].values())[0]
      top_module = self._insert_QuantLstm_in_top_module(top_module, name, info)

    # move modules info into layers info
    self._convert_modules_info_to_layers(module_graph_map)

    return top_module
Example #12
  def __init__(self,
               quant_mode: str,
               module: torch.nn.Module,
               input_args: Union[torch.Tensor, Sequence[Any]] = None,
               state_dict_file: Optional[str] = None,
               output_dir: str = "quantize_result",
               bitwidth_w: int = 8,
               bitwidth_a: int = 8,
               device: torch.device = torch.device("cuda"),
               lstm_app: bool = True,
               quant_config_file: Optional[str] = None):
    self._export_folder = output_dir
    # Check arguments type
    self._check_args(module)
    
    # Check device available
    if device.type == "cuda":
      if not (torch.cuda.is_available() and "CUDA_HOME" in os.environ):
        device = torch.device("cpu")
        NndctScreenLogger().warning("CUDA is not available, changing device to CPU")
    
    # Transform torch module to quantized module format
    nndct_utils.create_work_dir(output_dir)
    
    # turn off weights equalization and bias correction
    option_util.set_option_value("nndct_quant_opt", 0)
    option_util.set_option_value("nndct_param_corr", False)
    option_util.set_option_value("nndct_equalization", False)
    option_util.set_option_value("nndct_cv_app", False)
    
    # Parse the quant config file
    QConfiger = RNNTorchQConfig()
    #if quant_config_file:
    QConfiger.parse_config_file(quant_config_file,
                                bit_width_w = bitwidth_w, 
                                bit_width_a = bitwidth_a)
    qconfig = QConfiger.qconfig
    #bitwidth_w = qconfig['weight']['bit_width']
    #bitwidth_b = qconfig['bias']['bit_width']
    #bitwidth_a = qconfig['activation']['bit_width']
    #mix_bit = qconfig['mix_bit'] 

    transformed_module = convert_lstm(module)
    script_module = torch.jit.script(transformed_module)
    quant_module, graph = prepare_quantizable_module(
        module=script_module,
        input_args=None,
        export_folder=output_dir,
        state_dict_file=state_dict_file,
        quant_mode=quant_mode,
        device=device)
    
    #qstrategy_factory =  QstrategyFactory()
    #quant_strategy = qstrategy_factory.create_qstrategy(qconfig) 

    #quant_strategy = DefaultQstrategy(bits_weight=bitwidth_w,
    #                                  bits_bias=bitwidth_w,
    #                                  bits_activation=bitwidth_a)
    
    quantizer, qmode = self._init_quant_env(quant_mode, 
                                            output_dir,
                                            qconfig,
                                            is_lstm=True)
    
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANTIZER, quantizer)
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_MODE, qmode)
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_DEVICE, device)
    GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_CONFIG, qconfig)

    quantizer.quant_model = quant_module.to(device)
    
    quantizer.setup(graph, rnn_front_end=True, lstm=True)

    self.quantizer = quantizer
Example #13
    def __init__(self,
                 quant_mode: str,
                 module: torch.nn.Module,
                 input_args: Union[torch.Tensor, Sequence[Any]] = None,
                 state_dict_file: Optional[str] = None,
                 output_dir: str = "quantize_result",
                 bitwidth_w: int = 8,
                 bitwidth_a: int = 8,
                 mix_bit: bool = False,
                 device: torch.device = torch.device("cuda"),
                 lstm_app: bool = False,
                 custom_quant_ops: Optional[List[str]] = None,
                 quant_config_file: Optional[str] = None):
        # Check arguments type
        self._check_args(module, input_args)

        # Check device available
        if device.type == "cuda":
            if not (torch.cuda.is_available() and "CUDA_HOME" in os.environ):
                device = torch.device("cpu")
                NndctScreenLogger().warning(
                    "CUDA is not available, changing device to CPU")

        # Transform torch module to quantized module format
        nndct_utils.create_work_dir(output_dir)

        # Parse the quant config file
        QConfiger = TorchQConfig()
        #if quant_config_file:
        QConfiger.parse_config_file(quant_config_file,
                                    bit_width_w=bitwidth_w,
                                    bit_width_a=bitwidth_a,
                                    mix_bit=mix_bit)
        qconfig = QConfiger.qconfig
        #bitwidth_w = qconfig['weights']['bit_width']
        #bitwidth_b = qconfig['bias']['bit_width']
        #bitwidth_a = qconfig['activation']['bit_width']
        #mix_bit = qconfig['mix_bit']

        # Create a quantizer object, which controls the whole quantization flow.
        #qstrategy_factory = QstrategyFactory()
        #quant_strategy = qstrategy_factory.create_qstrategy(qconfig)
        #quant_strategy = DefaultQstrategy(bits_weight=bitwidth_w,
        #                                  bits_bias=bitwidth_a,
        #                                  bits_activation=bitwidth_a,
        #                                  mix_bit=mix_bit)
        quantizer, qmode = self._init_quant_env(quant_mode, output_dir,
                                                qconfig)

        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANTIZER, quantizer)
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_MODE, qmode)
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_DEVICE, device)
        GLOBAL_MAP.set_map(NNDCT_KEYS.QUANT_CONFIG, qconfig)
        option_util.set_option_value("nndct_cv_app", not lstm_app)

        # Prepare quantizable module

        quant_module, graph = prepare_quantizable_module(
            module=module,
            input_args=input_args,
            export_folder=output_dir,
            state_dict_file=state_dict_file,
            quant_mode=qmode,
            device=device)

        # enable recording of per-layer outputs
        if qmode > 1:
            register_output_hook(quant_module, record_once=True)
            set_outputs_recorder_status(quant_module, True)

        # initialize quantizer
        quantizer.setup(graph,
                        False,
                        lstm_app,
                        custom_quant_ops=custom_quant_ops)
        #if qmode > 1:
        #  quantizer.features_check()

        # hook module with quantizer
        # connect_module_with_quantizer(quant_module, quantizer)
        quantizer.quant_model = quant_module
        self._example_inputs = input_args

        self._lstm_app = lstm_app
        self.quantizer = quantizer
        self.adaquant = None

        # dump blob dist
        if NndctOption.nndct_visualize.value is True:
            visualize_tensors(quantizer.quant_model)
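Putting the pieces together, a minimal calibration sketch using the constructor above might look like the following. It relies only on what the examples show (TorchQuantProcessor takes a quant mode, a float module and example inputs; its quantizer exposes quant_model and export_quant_config(), as in Example #5), while the qproc import path and the toy model are assumptions for illustration.

import torch
import pytorch_nndct.qproc as qproc  # assumed import path for TorchQuantProcessor

# Toy float model and a dummy calibration input (illustrative only).
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval()
dummy_input = torch.randn(1, 3, 32, 32)

processor = qproc.TorchQuantProcessor(
    'calib',                      # quant_mode; 'test' would evaluate the quantized model
    model,
    dummy_input,
    output_dir='quantize_result',
    bitwidth_w=8,
    bitwidth_a=8,
    device=torch.device('cpu'))

quant_model = processor.quantizer.quant_model
quant_model(dummy_input)                    # forward passes collect calibration statistics
processor.quantizer.export_quant_config()   # write the calibration result to output_dir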