    def _get_init_module_str(self, node):
        torch_op_type = py_utils.get_torch_op_type(node.op.type)

        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
        op_name, is_defined_op = self._get_op_name(torch_op_type)

        def _init_attrs_str():
            attrs_str = ""
            if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
                attrs_str = self._to_map_str(
                    self._get_module_attrs_map(node, torch_op_type,
                                               torch_op_attr.attrs))

            if not is_defined_op:
                attrs_str = f"'{node.op.type}',{attrs_str}" if attrs_str else f"'{node.op.type}'"

            return attrs_str

        def _init_module_str():
            attrs_str = _init_attrs_str()
            return 'self.{module_name} = {op_name}({attrs}) #{node_name}'.format(
                module_name=self._get_module_name(node),
                op_name=op_name,
                attrs=attrs_str,
                node_name=node.name)

        return _init_module_str()

    def _get_forward_str(self, node):
        output_str = self._to_list_str(self._get_module_output(node))

        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)

        if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
            input_str = self._to_list_str(self._get_module_input(node))
            forward_str = "{output} = self.{module_name}({input})".format(
                output=output_str,
                module_name=self._get_module_name(node),
                input=input_str)
        else:
            func_attrs = self._get_module_attrs_map(node, torch_op_type,
                                                    torch_op_attr.attrs)

            self._infer_attrs(func_attrs)
            func_attrs_str = self._to_map_str(func_attrs)
            if torch_op_attr.op_class_type == TorchOpClassType.TENSOR:
                input = self._get_module_input(node)[0]
                func_attrs_str = f"input={input}, {func_attrs_str}"

            forward_str = "{output} = self.{module_name}({attrs})".format(
                output=output_str,
                module_name=self._get_module_name(node),
                attrs=func_attrs_str)

        return forward_str, output_str
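For orientation, here is a minimal sketch of the strings these two methods produce, built from the same format templates with made-up values (the module name, attribute values, node name, and the py_nndct.nn op prefix are all assumptions for illustration):

# Hypothetical init line, as emitted by _get_init_module_str:
init_line = 'self.{module_name} = {op_name}({attrs}) #{node_name}'.format(
    module_name='module_3',
    op_name='py_nndct.nn.Conv2d',  # assumed resolved op name
    attrs='in_channels=3, out_channels=16, kernel_size=[3, 3]',
    node_name='MyNet::Conv2d_3')
# -> "self.module_3 = py_nndct.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=[3, 3]) #MyNet::Conv2d_3"

# Hypothetical forward line for the same NN_MODULE node:
forward_line = '{output} = self.{module_name}({input})'.format(
    output='output_module_3', module_name='module_3', input='output_module_2')
# -> "output_module_3 = self.module_3(output_module_2)"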
Example #3
    def _get_init_module_str(self, node: Node) -> str:
        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
        if torch_op_attr.op_class_type != TorchOpClassType.UNKNOWN:
            op_name, attrs_str = self._init_op_and_attrs_str(node)

            return 'self.{module_name} = {op_name}({attrs}) #{node_name}'.format(
                module_name=self._get_module_name(node),
                op_name=op_name,
                attrs=attrs_str,
                node_name=node.name)
Example #4
    def _get_init_module_str(self, node: Node) -> str:
        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
        if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
            attrs_str = self._to_map_str(
                self._get_module_attrs_map(node, torch_op_type,
                                           torch_op_attr.attrs))
            return 'self.{module_name} = {op_name}({attrs}) #{node_name}'.format(
                module_name=self._get_module_name(node),
                op_name=torch_op_attr.op_name,
                attrs=attrs_str,
                node_name=node.name)
Example #5
    def _get_forward_str(self, node: Node) -> Tuple[str, str]:
        output_str = self._to_list_str(self._get_module_output(node))
        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)

        if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
            input_str = self._to_list_str(self._get_module_input(node))
            forward_str = "{output} = self.{module_name}({input})".format(
                output=output_str,
                module_name=self._get_module_name(node),
                input=input_str)

        elif torch_op_attr.op_class_type in (TorchOpClassType.NN_FUNCTION,
                                             TorchOpClassType.TORCH_FUNCTION,
                                             TorchOpClassType.NN_CORE_FUNCTION):
            func_attrs = self._get_module_attrs_map(node, torch_op_type,
                                                    torch_op_attr.attrs)
            self._infer_attrs(func_attrs)
            forward_str = "{output} = {op_name}({attrs}) #{node_name}".format(
                output=output_str,
                op_name=torch_op_attr.op_name,
                attrs=self._to_map_str(func_attrs),
                node_name=node.name)

        elif torch_op_attr.op_class_type == TorchOpClassType.TENSOR:
            input = self._get_module_input(node)[0]
            func_attrs = self._get_module_attrs_map(node, torch_op_type,
                                                    torch_op_attr.attrs)
            self._infer_attrs(func_attrs)
            if 'input' in func_attrs:
                del func_attrs['input']

            forward_str = "{output} = {input}.{op_name}({attrs}) #{node_name}".format(
                output=output_str,
                input=input,
                op_name=torch_op_attr.op_name,
                attrs=self._to_map_str(func_attrs),
                node_name=node.name)

        elif node.op.type in MISC_OP_DISCR_MAP:
            forward_str = MISC_OP_DISCR_MAP[node.op.type](self, node,
                                                          output_str)

        else:
            raise RuntimeError(
                'op_class_type of op is unknown, please check the operation: {}'
                .format(node.op.type))

        return forward_str, output_str
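A made-up illustration of the TENSOR branch above, which emits a method call on the first input rather than a module call (all concrete names are invented):

# e.g. for a hypothetical Tensor.view node:
#   output_module_5 = output_module_4.view(size=[1, -1]) #MyNet::View_5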
Example #6
    def _init_op_and_attrs_str(self, node: Node) -> Tuple[str, str]:
        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
        op_name = py_utils.get_defined_quant_module(torch_op_type)
        is_defined_op = bool(op_name)
        if not is_defined_op:
            op_name = ".".join([TorchSymbol.MODULE_PREFIX, "Module"])
        attrs_str = ""
        if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
            attrs_str = self._to_map_str(
                self._get_module_attrs_map(node, torch_op_type,
                                           torch_op_attr.attrs))

        if not is_defined_op:
            attrs_str = f"'{node.op.type}',{attrs_str}" if attrs_str else f"'{node.op.type}'"

        return op_name, attrs_str
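A minimal runnable sketch of the fallback branch above, with made-up values ("py_nndct.nn" stands in for TorchSymbol.MODULE_PREFIX and the op type string is hypothetical):

op_name, is_defined_op = ".".join(["py_nndct.nn", "Module"]), False
nndct_op_type = 'some_nndct_type'  # stands in for node.op.type
attrs_str = "dim=1"
if not is_defined_op:
    attrs_str = f"'{nndct_op_type}',{attrs_str}" if attrs_str else f"'{nndct_op_type}'"
# attrs_str -> "'some_nndct_type',dim=1", so the generated init line reads:
#   self.module_7 = py_nndct.nn.Module('some_nndct_type',dim=1) #MyNet::SomeOp_7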
Example #7
    def _get_forward_str(self, node: Node) -> Tuple[str, str]:
        output_str = self._to_list_str(self._get_module_output(node))

        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)

        if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
            input_str = self._to_list_str(self._get_module_input(node))
            forward_str = "{output} = self.{module_name}({input})".format(
                output=output_str,
                module_name=self._get_module_name(node),
                input=input_str)
        elif torch_op_attr.op_class_type in (TorchOpClassType.NN_FUNCTION,
                                             TorchOpClassType.TORCH_FUNCTION,
                                             TorchOpClassType.TENSOR,
                                             TorchOpClassType.PRIMITIVE):
            func_attrs = self._get_module_attrs_map(node, torch_op_type,
                                                    torch_op_attr.attrs)

            self._infer_attrs(func_attrs)
            func_attrs_str = self._to_map_str(func_attrs)
            if torch_op_attr.op_class_type == TorchOpClassType.TENSOR:
                input = self._get_module_input(node)[0]
                func_attrs_str = f"input={input}, {func_attrs_str}"

            forward_str = "{output} = self.{module_name}({attrs})".format(
                output=output_str,
                module_name=self._get_module_name(node),
                attrs=func_attrs_str)

        elif torch_op_attr.op_class_type == TorchOpClassType.UNKNOWN and node.op.type in MISC_OP_DISCR_MAP:
            forward_str = MISC_OP_DISCR_MAP[node.op.type](self, node,
                                                          output_str)

        else:
            raise RuntimeError(
                'op_class_type of op is unknown, please check the operation: {}'
                .format(node.op.type))

        return forward_str, output_str
Example #8
def Module(nndct_type, *args, **kwargs):
    torch_op_type = py_utils.get_torch_op_type(nndct_type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
    return creat_module(torch_op_type, torch_op_attr, *args, **kwargs)
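A hedged usage sketch of this factory; the NNDCT type string and keyword argument are assumptions, since real keys come from NNDCT's op registry:

# Hypothetical call: resolve the torch-side op for an NNDCT op type and
# instantiate it through creat_module with the forwarded arguments.
relu = Module('relu', inplace=False)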
Example #9
  def _assert_valid_model(self, allow_reused_module):
    # If two or more nodes point to a same module, then we will let them
    # use the same qconfig.
    module_to_qconfig = {}
    for node in self._graph.nodes:
      module_name = mod_util.module_name_from_node(node)
      if not module_name or node.name not in self._node_to_qconfig:
        continue

      if module_name in module_to_qconfig:
        if allow_reused_module:
          self._node_to_qconfig[node.name] = module_to_qconfig[module_name]
          logging.warn(
              ('Reused module ({}) may lead to low accuracy of QAT, '
               'make sure this is what you expect.').format(module_name))
        else:
          raise ValueError(
              ('Quantized module "{}" has been called multiple '
               'times in forward pass. If you want to share quantized '
               'parameters in multiple calls, call trainable_model with '
               '"allow_reused_module=True"').format(module_name))
      module_to_qconfig[module_name] = self._node_to_qconfig[node.name]

    # Make sure all quantizable operations are instance of torch.nn.Module.
    replacement_map = {
        OpTypes.ADD: ('torch.add/+', functional.Add),
        OpTypes.CONCAT: ('torch.cat', functional.Cat),
        OpTypes.MAX: ('torch.max', functional.Max),
        OpTypes.PAD: ('torch.nn.functional.pad', functional.Pad),
        OpTypes.RELU: ('torch.nn.functional.relu', torch.nn.ReLU),
        OpTypes.SUM: ('torch.sum', functional.Sum),
    }

    for name, group in self._quant_config.items():
      if name not in self._qinfo_keys:
        continue
      for key in group:
        node_name, _ = self._quant_config[name][key]

        module = mod_util.get_module_by_node(self._model, node_name)
        node = self._graph.node(node_name)

        module_cls = type(module) if module else None
        if node.op.type in replacement_map:
          op, target_cls = replacement_map[node.op.type]
          if module_cls != target_cls:
            raise ValueError(
                ('Quantized operation({}) must be instance '
                 'of "torch.nn.Module", please replace {} with {}').format(
                     node.name, op, target_cls))

        # A quantized op must be implemented as a module.
        if not module:
          if node.op.type == OpTypes.INPUT:
            raise ValueError(
                ('Input is not quantized. Please use QuantStub/DeQuantStub to '
                 'define quantization scope.'))
          else:
            raise ValueError(
                ('Can not quantize node "{}({})" as it is not a '
                 'torch.nn.Module object, please re-implement this operation '
                 'as a module.').format(node.name, node.op.type))

        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
        if not torch_op_attr.op_name.startswith('torch'):
          logging.vlog(1,
                       'Non-torch op found: {}'.format(torch_op_attr.op_name))
          continue

        # Check if we get the correct module.
        op_type_name = torch_op_attr.op_name.split('.')[-1]
        logging.vlog(
            1, '{}({}): {} vs. {}'.format(node.name, node.op.type,
                                          module_cls.__name__,
                                          torch_op_attr.op_name))
        if not module_cls.__module__.startswith(
            'pytorch_nndct') and module_cls.__name__ != op_type_name:
          raise ValueError(('{} is a quantized operation, please re-implement '
                            'your op as a nn.Module (Node: {})').format(
                                torch_op_attr.op_name, node_name))
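The replacement_map above lists the required module-based substitutes; a minimal sketch of one such rewrite, assuming `functional` is the same pytorch_nndct module referenced in the map and that functional.Add's forward takes the two addends:

import torch

class AddBlock(torch.nn.Module):
    # Hypothetical user module: routes `x + y` through a quantizable Module
    # instance so _assert_valid_model can map the op to a torch.nn.Module.
    def __init__(self):
        super().__init__()
        self.add = functional.Add()  # replaces a bare torch.add / `+`

    def forward(self, x, y):
        return self.add(x, y)  # instead of `return x + y`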
Example #10
def Module(nndct_type, *args, **kwargs):
    quant_mode, _ = maybe_get_quantizer()
    torch_op_type = py_utils.get_torch_op_type(nndct_type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
    return creat_module(torch_op_type, torch_op_attr, *args, **kwargs)
Example #11
    def _get_forward_str(self, node: Node) -> Tuple[str, str]:
        output_str = self._to_list_str(self._get_module_output(node))

        torch_op_type = py_utils.get_torch_op_type(node.op.type)
        torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)

        if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
            input_str = self._to_list_str(self._get_module_input(node))
            forward_str = "{output} = self.{module_name}({input})".format(
                output=output_str,
                module_name=self._get_module_name(node),
                input=input_str)
        elif torch_op_attr.op_class_type in (TorchOpClassType.NN_FUNCTION,
                                             TorchOpClassType.TORCH_FUNCTION,
                                             TorchOpClassType.TENSOR,
                                             TorchOpClassType.PRIMITIVE,
                                             TorchOpClassType.NN_CORE_FUNCTION):
            func_attrs = self._get_module_attrs_map(node, torch_op_type,
                                                    torch_op_attr.attrs)
            if not func_attrs:
                input_str = self._to_list_str(self._get_module_input(node))
                forward_str = "{output} = self.{module_name}({input})".format(
                    output=output_str,
                    module_name=self._get_module_name(node),
                    input=input_str)
            else:
                self._infer_attrs(func_attrs)
                func_attrs_str = self._to_map_str(func_attrs)
                if torch_op_attr.op_class_type == TorchOpClassType.TENSOR and "input" not in func_attrs:
                    input = self._get_module_input(node)[0]
                    func_attrs_str = f"input={input}, {func_attrs_str}"

                forward_str = "{output} = self.{module_name}({attrs})".format(
                    output=output_str,
                    module_name=self._get_module_name(node),
                    attrs=func_attrs_str)

        elif torch_op_attr.op_class_type == TorchOpClassType.UNKNOWN and node.op.type in MISC_OP_DISCR_MAP:
            forward_str = MISC_OP_DISCR_MAP[node.op.type](self, node,
                                                          output_str)

        elif torch_op_attr.op_class_type in [
                TorchOpClassType.TORCH_SCRIPT_BUILTIN_FUNCTION,
                TorchOpClassType.MATH_BUILTIN_FUNCTION,
                TorchOpClassType.GLOBAL_BUILTIN_FUNCTION,
                TorchOpClassType.CUSTOM_FUNCTION
        ]:
            func_attrs = self._get_module_attrs_map(node, torch_op_type,
                                                    torch_op_attr.attrs)
            if not func_attrs:
                input_str = self._to_list_str(self._get_module_input(node))
                forward_str = "{output} = self.{module_name}({input})".format(
                    output=output_str,
                    module_name=self._get_module_name(node),
                    input=input_str)
            else:
                self._infer_attrs(func_attrs)
                args = list(func_attrs.values())
                args_str = self._to_list_str(args)
                forward_str = "{output} = self.{module_name}({attrs})".format(
                    output=output_str,
                    module_name=self._get_module_name(node),
                    attrs=args_str)

        else:
            raise RuntimeError(
                f'op_class_type of op ({torch_op_attr.op_class_type.value}) is unknown, please check the operation: {node.op.type}.'
            )

        return forward_str, output_str
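For the UNKNOWN branch above, a hedged sketch of what a MISC_OP_DISCR_MAP entry could look like; the handler signature mirrors the call site (self, node, output_str), but the key and the body are assumptions, not the library's actual table:

def _shape_discr(writer, node, output_str):
    # Receives the code writer, the graph node, and the pre-formatted output
    # string; returns one line of generated forward code.
    input_str = writer._to_list_str(writer._get_module_input(node))
    return "{output} = {inp}.shape #{name}".format(
        output=output_str, inp=input_str, name=node.name)

MISC_OP_DISCR_MAP = {"shape": _shape_discr}  # hypothetical registration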