Example No. 1
  def _insert_quantizer(self, model_topo, allow_reused_module):
    """Insert quantizer for quantizing input/output of a module.
      The quantization of weight/bias is handled by quantized module itself.
    """

    quantized_modules = set()
    for node in model_topo.nodes:
      qconfig = node.qconfig
      if not qconfig:
        continue

      # Check if there are parameterized modules that have not been
      # transformed to the corresponding quantized version.
      if qconfig.weight or qconfig.bias:
        if not hasattr(node.module, 'is_quantized'):
          raise NotImplementedError(
              ('The quantization of {} is not implemented '
               'yet. (Node name: {})').format(type(node.module), node.name))

      if node.name in quantized_modules:
        continue

      logging.vlog(
          3,
          'Inserting quantizer for node {}: {}'.format(node.graph_node.name,
                                                       qconfig))
      quantized_modules.add(node.name)
      if qconfig.input:
        # Reserved for multiple-input support; currently the index is always 0.
        quantize_input(node.module, 0, qconfig.input)

      if qconfig.output:
        quantize_output(node.module, qconfig.output)
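
The quantize_input and quantize_output helpers are not shown in this example. Below is a minimal sketch of one way they could be implemented with standard PyTorch hooks, assuming each quantizer is a callable fake-quantization module (the real pytorch_nndct helpers may work differently):

def quantize_input(module, index, quantizer):
  """Fake-quantizes the index-th positional input of `module` via a pre-hook."""

  def _pre_hook(mod, inputs):
    inputs = list(inputs)
    inputs[index] = quantizer(inputs[index])
    return tuple(inputs)

  module.register_forward_pre_hook(_pre_hook)


def quantize_output(module, quantizer):
  """Fake-quantizes the output of `module` via a forward hook."""

  def _hook(mod, inputs, output):
    return quantizer(output)

  module.register_forward_hook(_hook)
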
Example No. 2
def insert_quantizer(model_topo):
  """Insert quantizer for quantizing input/output of a module.
  The quantization of weight/bias is handled by quantized module itself.
  """

  quantized_modules = set()
  for node in model_topo.nodes:
    rt_spec = node.spec
    if not rt_spec:
      continue

    # The check that parameterized modules have been transformed to their
    # corresponding quantized versions is disabled in this variant:
    #if qconfig.weight or qconfig.bias:
    #  if not hasattr(node.module, 'is_quantized'):
    #    raise NotImplementedError(
    #        ('The quantization of {} not implemented '
    #         'yet. (Node name: {})').format(type(node.module), node.name))

    if node.name in quantized_modules:
      continue

    logging.vlog(
        3, 'Inserting quantizer for node {}: {}'.format(node.graph_node.name,
                                                        rt_spec))
    quantized_modules.add(node.name)
    for index, quantizer in enumerate(rt_spec.input_quantizers):
      quantize_input(node.module, index, quantizer)

    output_quantizers = rt_spec.output_quantizers
    if len(output_quantizers) > 1:
      raise NotImplementedError(
          'Multiple output tensors are not supported yet.')

    if output_quantizers:
      quantize_output(node.module, output_quantizers[0])
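
The node.spec object consumed here matches the LayerRuntimeSpec built up in Example No. 4. Below is a minimal sketch of such a container, limited to the attributes and methods these examples actually use; the real pytorch_nndct class likely carries more state:

class LayerRuntimeSpec:
  """Per-layer runtime spec: quantizers attached to a module's tensors."""

  def __init__(self):
    self._input_quantizers = []
    self._output_quantizers = []
    self._weight_quantizers = {}  # parameter attribute name -> quantizer

  @property
  def input_quantizers(self):
    return self._input_quantizers

  @property
  def output_quantizers(self):
    return self._output_quantizers

  def add_input_quantizer(self, quantizer):
    self._input_quantizers.append(quantizer)

  def add_output_quantizer(self, quantizer):
    self._output_quantizers.append(quantizer)

  def add_weight_quantizer(self, attr, quantizer):
    self._weight_quantizers[attr] = quantizer
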
Example No. 3
  def _to_deployable(self, trained_model, output_dir):
    if not self._quant_config or self._module_map is None:
      raise RuntimeError('Must call "trainable_model" first.')

    if getattr(trained_model, 'conv_bn_fused', False):
      raise RuntimeError(
          'Not allowed to convert a fused model to a deployable model.')

    # Copy trained parameters from transformed model to original float model.
    orig_state_dict = self._model.state_dict()
    trained_state_dict = trained_model.state_dict()
    state_dict = {}
    for key in orig_state_dict:
      if '.' in key:
        module_name, weight_name = key.rsplit('.', 1)
      else:
        # Such as 'global_step'.
        module_name, weight_name = None, key
      if module_name in self._module_map:
        # Currently only for bn.
        # conv1.0.0.bn.weight -> conv1.0.1.weight
        trained_module_name = self._module_map[module_name]
        trained_key = '.'.join([trained_module_name, weight_name])
      else:
        trained_key = key
      state_dict[key] = trained_state_dict[trained_key]
      logging.vlog(3, 'state dict of {} is from {}'.format(key, trained_key))
    model = copy.deepcopy(self._model)
    model.load_state_dict(state_dict)
    model.eval()

    qprocessor = qproc.TorchQuantProcessor(
        'test',
        model,
        self._inputs,
        output_dir=self._tmp_qat_dir,
        bitwidth_w=self._bitwidth,
        bitwidth_a=self._bitwidth,
        mix_bit=self._mix_bit,
        device=self._device)

    quantizer = qprocessor.quantizer
    self._fill_in_quant_config(quantizer)

    sub_dir = os.path.join(output_dir, 'test')
    io_util.create_work_dir(sub_dir)
    # Must set adjust_pos=False first, because the quantizer modifies its
    # quant info in place when adjust_pos=True.
    # Export the original (not yet adjusted) quant info for testing the
    # deployable model; its accuracy should match the trainable model's.
    quantizer.export_quant_config(
        os.path.join(sub_dir, _QUANT_INFO_FILE_NAME), adjust_pos=False)
    quantizer.export_quant_config(
        os.path.join(output_dir, _QUANT_INFO_FILE_NAME), adjust_pos=True)

    self._qprocessor = qprocessor
    return model
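
The state_dict key remapping above can be illustrated on its own. In the toy snippet below the module_map entry is made up, mirroring the 'conv1.0.0.bn.weight -> conv1.0.1.weight' comment; in _to_deployable the mapping comes from self._module_map:

module_map = {'conv1.0.0.bn': 'conv1.0.1'}  # hypothetical entry

def remap_key(key):
  if '.' in key:
    module_name, weight_name = key.rsplit('.', 1)
  else:
    # Keys without a module prefix, such as 'global_step'.
    module_name, weight_name = None, key
  if module_name in module_map:
    return '.'.join([module_map[module_name], weight_name])
  return key

assert remap_key('conv1.0.0.bn.weight') == 'conv1.0.1.weight'
assert remap_key('global_step') == 'global_step'
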
Example No. 4
  def __init__(self,
               model,
               inputs,
               bitwidth,
               mix_bit=False,
               device=torch.device("cuda")):

    if isinstance(model, torch.nn.DataParallel):
      raise ValueError('DataParallel object is not allowed.')

    # Turn off optimization options for the following quantization.
    option_util.set_option_value("nndct_quant_opt", 0)
    option_util.set_option_value("nndct_param_corr", False)
    option_util.set_option_value("nndct_equalization", False)

    self._model = model
    self._inputs = inputs
    self._bitwidth = bitwidth
    self._mix_bit = mix_bit
    self._device = device

    # Original module name to transformed module name.
    # We can use it to convert the transformed model's state_dict keys
    # so that the original float model can load it.
    self._module_map = None

    self._trainable_model = None
    self._tmp_qat_dir = '.qat'

    qprocessor = qproc.TorchQuantProcessor(
        'calib',
        model,
        inputs,
        output_dir=self._tmp_qat_dir,
        bitwidth_w=self._bitwidth,
        bitwidth_a=self._bitwidth,
        mix_bit=mix_bit,
        device=device)
    quantizer = qprocessor.quantizer
    self._torch_quantizer = quantizer

    self._qinfo_keys = [
        TensorTypes.PARAM, TensorTypes.INPUT, TensorTypes.OUTPUT
    ]

    # Use hard-coded value to fill in fp_pos and export quant config,
    # so that we can initialize a new TorchQuantProcessor in 'test' mode later.
    quant_config = quantizer.quant_config
    for key, group in quant_config.items():
      if key not in self._qinfo_keys:
        continue
      for item in group:
        group[item][-1] = 4
    quantizer.export_quant_config(adjust_pos=False)

    # Use the quantizer's graph to build _tensor_to_node, as the quant_info is
    # generated from the quantizer's graph.
    # For example, the param 'ResNet::conv.bias' only exists in the quantizer's
    # graph because it comes from the fused conv + bias.
    self._tensor_to_node = {}
    graph = quantizer.Nndctgraph
    for node in graph.nodes:
      for name, tensor in node.op.params.items():
        self._tensor_to_node[tensor.name] = (node.name, name)

    parser = parse.TorchParser()
    self._graph = parser(self._model._get_name(), self._model, self._inputs)

    quant_optimizer = QuantOptimizer()
    if NndctOption.nndct_partition_mode.value > 0:
      quant_optimizer._tag_quant_nodes_v2(self._graph)
    else:
      quant_optimizer._tag_quant_nodes(self._graph)

    def get_bitwidth(quant_info):
      return quant_info[0] if quant_info[0] == 8 else bitwidth

    # Create quantizer for each item in quant config.
    self._node_to_qconfig = {}
    self._quant_config = copy.deepcopy(quant_config)
    for name, group in self._quant_config.items():
      if name not in self._qinfo_keys:
        continue
      for key, qinfo in group.items():
        if name == TensorTypes.PARAM:
          node, param = self._tensor_to_node[key]
          attr = ModuleHooker._parameter_map[param]
          tensor_type = 'weight'
        else:
          node, attr = key, None
          tensor_type = 'act'

        tqt_quantizer = TQTQuantizer(get_bitwidth(qinfo), tensor_type)
        qconfig = self._node_to_qconfig.get(node, config_mod.LayerRuntimeSpec())
        if name == TensorTypes.PARAM:
          qconfig.add_weight_quantizer(attr, tqt_quantizer)
        elif name == TensorTypes.INPUT:
          qconfig.add_input_quantizer(tqt_quantizer)
        else:
          qconfig.add_output_quantizer(tqt_quantizer)

        self._node_to_qconfig[node] = qconfig
        self._quant_config[name][key] = (node, attr)
        logging.vlog(2, '[{}][{}] = ({}, {})'.format(name, key, node, attr))
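
A minimal usage sketch, assuming this constructor belongs to the QatProcessor class exported by pytorch_nndct (the class name is not visible in the snippet) and that public trainable_model/to_deployable methods wrap the private ones shown in Example No. 3; the toy model, input shape, and output directory are placeholders:

import torch

from pytorch_nndct import QatProcessor                # assumed import path
from pytorch_nndct.nn import QuantStub, DeQuantStub   # assumed import path

class TinyNet(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.quant_stub = QuantStub()
    self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
    self.bn = torch.nn.BatchNorm2d(8)
    self.relu = torch.nn.ReLU()
    self.dequant_stub = DeQuantStub()

  def forward(self, x):
    x = self.quant_stub(x)
    x = self.relu(self.bn(self.conv(x)))
    return self.dequant_stub(x)

model = TinyNet()
inputs = torch.randn(1, 3, 32, 32)

processor = QatProcessor(model, inputs, bitwidth=8, device=torch.device('cpu'))
trainable = processor.trainable_model()
# ... run quantization-aware training on `trainable` ...
deployable = processor.to_deployable(trainable, output_dir='qat_result')
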
Example No. 5
  def _assert_valid_model(self, allow_reused_module):
    # If two or more nodes point to the same module, we let them share the
    # same qconfig.
    module_to_qconfig = {}
    for node in self._graph.nodes:
      module_name = mod_util.module_name_from_node(node)
      if not module_name or node.name not in self._node_to_qconfig:
        continue

      if module_name in module_to_qconfig:
        if allow_reused_module:
          self._node_to_qconfig[node.name] = module_to_qconfig[module_name]
          logging.warn(
              ('Reused module ({}) may lead to low accuracy of QAT; '
               'make sure this is what you expect.').format(module_name))
        else:
          raise ValueError(
              ('Quantized module "{}" has been called multiple '
               'times in forward pass. If you want to share quantized '
               'parameters in multiple calls, call trainable_model with '
               '"allow_reused_module=True"').format(module_name))
      module_to_qconfig[module_name] = self._node_to_qconfig[node.name]

    # Make sure all quantizable operations are instance of torch.nn.Module.
    replacement_map = {
        OpTypes.ADD: ('torch.add/+', functional.Add),
        OpTypes.CONCAT: ('torch.cat', functional.Cat),
        OpTypes.MAX: ('torch.max', functional.Max),
        OpTypes.PAD: ('torch.nn.functional.pad', functional.Pad),
        OpTypes.RELU: ('torch.nn.functional.relu', torch.nn.ReLU),
        OpTypes.SUM: ('torch.sum', functional.Sum),
    }

    for name, group in self._quant_config.items():
      if name not in self._qinfo_keys:
        continue
      for key in group:
        node_name, _ = self._quant_config[name][key]

        module = mod_util.get_module_by_node(self._model, node_name)
        node = self._graph.node(node_name)

        module_cls = type(module) if module else None
        if node.op.type in replacement_map:
          op, target_cls = replacement_map[node.op.type]
          if module_cls != target_cls:
            raise ValueError(
                ('Quantized operation ({}) must be an instance '
                 'of "torch.nn.Module"; please replace {} with {}').format(
                     node.name, op, target_cls))

        # A quantized op must be implemented as a module.
        if not module:
          if node.op.type == OpTypes.INPUT:
            raise ValueError(
                ('Input is not quantized. Please use QuantStub/DeQuantStub to '
                 'define quantization scope.'))
          else:
            raise ValueError(
                ('Cannot quantize node "{}({})" as it is not a '
                 'torch.nn.Module object; please re-implement this operation '
                 'as a module.').format(node.name, node.op.type))

        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
        if not torch_op_attr.op_name.startswith('torch'):
          logging.vlog(1,
                       'Non-torch op found: {}'.format(torch_op_attr.op_name))
          continue

        # Check if we get the correct module.
        op_type_name = torch_op_attr.op_name.split('.')[-1]
        logging.vlog(
            1, '{}({}): {} vs. {}'.format(node.name, node.op.type,
                                          module_cls.__name__,
                                          torch_op_attr.op_name))
        if not module_cls.__module__.startswith(
            'pytorch_nndct') and module_cls.__name__ != op_type_name:
          raise ValueError(('{} is a quantized operation; please re-implement '
                            'your op as an nn.Module (Node: {})').format(
                                torch_op_attr.op_name, node_name))
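
The replacement_map check means that operator or free-function calls such as x + y, torch.cat, or torch.nn.functional.pad must be rewritten as nn.Module instances before QAT. An illustrative rewrite, assuming functional.Add is a thin module wrapper around torch.add living under pytorch_nndct.nn.modules (the exact import path is an assumption):

import torch

from pytorch_nndct.nn.modules import functional  # assumed import path

class ResidualBlock(torch.nn.Module):

  def __init__(self, channels):
    super().__init__()
    self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
    # Module wrapper for the elementwise add so _assert_valid_model accepts it.
    self.skip_add = functional.Add()

  def forward(self, x):
    # Instead of `return self.conv(x) + x`, which would fail the check above:
    return self.skip_add(self.conv(x), x)
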