Example #1
def insert_quantizer(model, node_to_qconfig):
    for node, qconfig in node_to_qconfig.items():
        if not qconfig.input and not qconfig.output:
            continue

        module_name = mod_util.module_name_from_node(node)
        if not module_name:
            raise ValueError(('Cannot find module for node "{}". '
                              'Only module objects can be quantized; '
                              'please re-implement this operation as a module.'
                              ).format(node))
        module = mod_util.get_module(model, module_name)
        if qconfig.input:
            # Reserved for multi-input support; the input index is currently always 0.
            quantize_input(module, 0, qconfig.input)

        if qconfig.output:
            quantize_output(module, qconfig.output)
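The quantize_input and quantize_output helpers come from the surrounding framework and are not shown here. A minimal sketch of what they might look like, assuming a hook-based fake-quantization scheme (the FakeQuantize defaults and hook wiring below are illustrative assumptions, not the actual pytorch_nndct implementation):

import torch

# Hypothetical sketch only: attach a fake-quantize module to a layer's
# input or output via forward hooks. In a real implementation the qconfig
# argument would configure the observer; here it is unused.
def quantize_input(module, index, input_qconfig):
    fake_quant = torch.quantization.FakeQuantize()
    def pre_hook(mod, inputs):
        # Fake-quantize only the input at the given position.
        quantized = list(inputs)
        quantized[index] = fake_quant(quantized[index])
        return tuple(quantized)
    module.register_forward_pre_hook(pre_hook)

def quantize_output(module, output_qconfig):
    fake_quant = torch.quantization.FakeQuantize()
    def hook(mod, inputs, output):
        # Returning a value from a forward hook replaces the output.
        return fake_quant(output)
    module.register_forward_hook(hook)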
Example #2
  def _assert_valid_model(self, allow_reused_module):
    # If two or more nodes point to the same module, let them
    # share the same qconfig.
    module_to_qconfig = {}
    for node in self._graph.nodes:
      module_name = mod_util.module_name_from_node(node)
      if not module_name or node.name not in self._node_to_qconfig:
        continue

      if module_name in module_to_qconfig:
        if allow_reused_module:
          self._node_to_qconfig[node.name] = module_to_qconfig[module_name]
          logging.warning(
              ('Reused module ({}) may lead to low QAT accuracy; '
               'make sure this is what you expect.').format(module_name))
        else:
          raise ValueError(
              ('Quantized module "{}" has been called multiple '
               'times in the forward pass. If you want to share quantized '
               'parameters across calls, call trainable_model with '
               '"allow_reused_module=True".').format(module_name))
      module_to_qconfig[module_name] = self._node_to_qconfig[node.name]

    # Make sure all quantizable operations are instances of torch.nn.Module.
    replacement_map = {
        OpTypes.ADD: ('torch.add/+', functional.Add),
        OpTypes.CONCAT: ('torch.cat', functional.Cat),
        OpTypes.MAX: ('torch.max', functional.Max),
        OpTypes.PAD: ('torch.nn.functional.pad', functional.Pad),
        OpTypes.RELU: ('torch.nn.functional.relu', torch.nn.ReLU),
        OpTypes.SUM: ('torch.sum', functional.Sum),
    }

    for name, group in self._quant_config.items():
      if name not in self._qinfo_keys:
        continue
      for key in group:
        node_name, _ = group[key]

        module = mod_util.get_module_by_node(self._model, node_name)
        node = self._graph.node(node_name)

        module_cls = type(module) if module else None
        if node.op.type in replacement_map:
          op, target_cls = replacement_map[node.op.type]
          if module_cls != target_cls:
            raise ValueError(
                ('Quantized operation ({}) must be an instance '
                 'of torch.nn.Module; please replace {} with {}.').format(
                     node.name, op, target_cls))

        # A quantized op must be implemented as a module.
        if not module:
          if node.op.type == OpTypes.INPUT:
            raise ValueError(
                ('Input is not quantized. Please use QuantStub/DeQuantStub to '
                 'define the quantization scope.'))
          else:
            raise ValueError(
                ('Cannot quantize node "{}({})" because it is not a '
                 'torch.nn.Module object; please re-implement this operation '
                 'as a module.').format(node.name, node.op.type))

        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
        if not torch_op_attr.op_name.startswith('torch'):
          logging.vlog(1,
                       'Non-torch op found: {}'.format(torch_op_attr.op_name))
          continue

        # Check if we get the correct module.
        op_type_name = torch_op_attr.op_name.split('.')[-1]
        logging.vlog(
            1, '{}({}): {} vs. {}'.format(node.name, node.op.type,
                                          module_cls.__name__,
                                          torch_op_attr.op_name))
        if not module_cls.__module__.startswith(
            'pytorch_nndct') and module_cls.__name__ != op_type_name:
          raise ValueError(('{} is a quantized operation; please re-implement '
                            'your op as an nn.Module (node: {})').format(
                                torch_op_attr.op_name, node_name))
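The reused-module guard above protects against a single module instance being invoked more than once in the traced graph, which would silently share quantization parameters across call sites. A self-contained sketch of the same idea, using forward hooks to count calls (illustrative only, not the mechanism pytorch_nndct uses):

import torch

# Count how many times each submodule runs in a single forward pass; a
# count above one mirrors the "reused module" condition that
# _assert_valid_model rejects (or tolerates when allow_reused_module=True).
def find_reused_modules(model, example_input):
  counts = {name: 0 for name, _ in model.named_modules()}
  handles = []
  for name, mod in model.named_modules():
    def hook(m, inputs, output, name=name):
      counts[name] += 1
    handles.append(mod.register_forward_hook(hook))
  with torch.no_grad():
    model(example_input)
  for handle in handles:
    handle.remove()
  return [name for name, n in counts.items() if n > 1]

# The same ReLU instance called twice is exactly the case the check flags:
class Net(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.relu = torch.nn.ReLU()
  def forward(self, x):
    return self.relu(self.relu(x))

# find_reused_modules(Net(), torch.randn(1, 4))  -> ['relu']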
Example #3
def _topo_node_name(node):
    module_name = mod_util.module_name_from_node(node)
    node_name = node if isinstance(node, str) else node.name
    # Use the node name for non-module nodes so that
    # we get a complete topology.
    return module_name if module_name else node_name
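_topo_node_name prefers the owning module's name and falls back to the raw node name, so non-module nodes still appear in the topology. A hypothetical illustration of the fallback, where FakeNode and the module_name parameter stand in for the real graph node type and the mod_util lookup:

# Hypothetical stand-ins; the real code resolves module_name via
# mod_util.module_name_from_node(node).
class FakeNode:
    def __init__(self, name):
        self.name = name

def topo_node_name(node, module_name=None):
    node_name = node if isinstance(node, str) else node.name
    return module_name if module_name else node_name

assert topo_node_name(FakeNode('add_1')) == 'add_1'            # non-module node
assert topo_node_name(FakeNode('conv'), 'layer1.conv') == 'layer1.conv'
assert topo_node_name('flatten') == 'flatten'                  # plain string node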