def _get_init_module_str(self, node):
    """Build the ``__init__`` statement that instantiates the module for ``node``.

    Args:
        node: Graph node whose ``op.type`` selects the torch op and whose
            ``name`` is appended to the statement as a trailing comment.

    Returns:
        A source line of the form
        ``self.<module_name> = <op_name>(<attrs>) #<node_name>``.
    """
    torch_op_type = py_utils.get_torch_op_type(node.op.type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
    op_name, is_defined_op = self._get_op_name(torch_op_type)

    # Only nn.Module-style ops take constructor attributes.
    ctor_args = ""
    if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
        attrs_map = self._get_module_attrs_map(node, torch_op_type,
                                               torch_op_attr.attrs)
        ctor_args = self._to_map_str(attrs_map)

    # When `_get_op_name` reports the op as not defined, the quoted op type
    # is passed as the leading positional argument of the constructor call.
    if not is_defined_op:
        quoted_type = f"'{node.op.type}'"
        ctor_args = f"{quoted_type},{ctor_args}" if ctor_args else quoted_type

    template = 'self.{module_name} = {op_name}({attrs}) #{node_name}'
    return template.format(
        module_name=self._get_module_name(node),
        op_name=op_name,
        attrs=ctor_args,
        node_name=node.name)
def _get_forward_str(self, node):
    """Generate the forward-pass statement that invokes the module of ``node``.

    NN_MODULE ops are called with the node inputs as positional arguments.
    Every other op class is called with its attribute map rendered as keyword
    arguments; TENSOR ops additionally receive the first node input as the
    ``input`` keyword.

    Args:
        node: Graph node to generate the call for.

    Returns:
        A ``(forward_str, output_str)`` tuple: the generated statement and the
        rendered output names that appear on its left-hand side.
    """
    output_str = self._to_list_str(self._get_module_output(node))
    torch_op_type = py_utils.get_torch_op_type(node.op.type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)

    if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
        input_str = self._to_list_str(self._get_module_input(node))
        forward_str = "{output} = self.{module_name}({input})".format(
            output=output_str,
            module_name=self._get_module_name(node),
            input=input_str)
    else:
        func_attrs = self._get_module_attrs_map(node, torch_op_type,
                                                torch_op_attr.attrs)
        # Return value is unused; presumably updates the map in place
        # (NOTE(review): confirm against _infer_attrs).
        self._infer_attrs(func_attrs)
        func_attrs_str = self._to_map_str(func_attrs)
        # Prepend the tensor operand only when the attribute map does not
        # already carry an "input" entry; otherwise the generated call would
        # repeat the `input=` keyword, which is a SyntaxError in the emitted
        # code. A falsy map keeps the original unconditional behaviour.
        if (torch_op_attr.op_class_type == TorchOpClassType.TENSOR
                and (not func_attrs or "input" not in func_attrs)):
            first_input = self._get_module_input(node)[0]
            func_attrs_str = f"input={first_input}, {func_attrs_str}"
        forward_str = "{output} = self.{module_name}({attrs})".format(
            output=output_str,
            module_name=self._get_module_name(node),
            attrs=func_attrs_str)

    return forward_str, output_str
def _get_init_module_str(self, node: Node) -> str:
    """Build the ``__init__`` statement that instantiates the module for ``node``.

    Args:
        node: Graph node to generate the instantiation for.

    Returns:
        ``self.<module_name> = <op_name>(<attrs>) #<node_name>`` for every op
        whose class type is not ``UNKNOWN``. For ``UNKNOWN`` ops nothing is
        generated and ``None`` is returned, despite the ``str`` annotation.
    """
    torch_op_type = py_utils.get_torch_op_type(node.op.type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)

    # Ops of unknown class get no module instance in __init__.
    if torch_op_attr.op_class_type == TorchOpClassType.UNKNOWN:
        return None

    op_name, attrs_str = self._init_op_and_attrs_str(node)
    template = 'self.{module_name} = {op_name}({attrs}) #{node_name}'
    return template.format(
        module_name=self._get_module_name(node),
        op_name=op_name,
        attrs=attrs_str,
        node_name=node.name)
def _get_init_module_str(self, node: Node) -> str:
    """Build the ``__init__`` statement for an nn.Module-style op.

    Args:
        node: Graph node to generate the instantiation for.

    Returns:
        ``self.<module_name> = <op_name>(<attrs>) #<node_name>`` when the op
        class type is ``NN_MODULE``. Any other class type yields ``None``,
        despite the ``str`` annotation.
    """
    torch_op_type = py_utils.get_torch_op_type(node.op.type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)

    # Only nn.Module ops are instantiated in __init__.
    if torch_op_attr.op_class_type != TorchOpClassType.NN_MODULE:
        return None

    attrs_map = self._get_module_attrs_map(node, torch_op_type,
                                           torch_op_attr.attrs)
    ctor_args = self._to_map_str(attrs_map)
    template = 'self.{module_name} = {op_name}({attrs}) #{node_name}'
    return template.format(
        module_name=self._get_module_name(node),
        op_name=torch_op_attr.op_name,
        attrs=ctor_args,
        node_name=node.name)
def _get_forward_str(self, node: Node) -> str:
    """Generate the forward-pass statement for ``node``.

    The shape of the statement depends on the op class type:

    * ``NN_MODULE``: ``<out> = self.<module_name>(<inputs>)``
    * ``NN_FUNCTION`` / ``TORCH_FUNCTION`` / ``NN_CORE_FUNCTION``:
      ``<out> = <op_name>(<attrs>) #<node_name>``
    * ``TENSOR``: ``<out> = <first_input>.<op_name>(<attrs>) #<node_name>``
      with any ``input`` attribute removed from the attrs.
    * otherwise, when ``node.op.type`` is a key of ``MISC_OP_DISCR_MAP``,
      the registered callable produces the statement.

    Args:
        node: Graph node to generate the statement for.

    Returns:
        A ``(forward_str, output_str)`` tuple (the ``str`` annotation is kept
        for interface compatibility).

    Raises:
        RuntimeError: If none of the cases above applies.
    """
    output_str = self._to_list_str(self._get_module_output(node))
    torch_op_type = py_utils.get_torch_op_type(node.op.type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
    op_class = torch_op_attr.op_class_type

    if op_class == TorchOpClassType.NN_MODULE:
        call_args = self._to_list_str(self._get_module_input(node))
        statement = "{output} = self.{module_name}({input})".format(
            output=output_str,
            module_name=self._get_module_name(node),
            input=call_args)
        return statement, output_str

    if op_class in (TorchOpClassType.NN_FUNCTION,
                    TorchOpClassType.TORCH_FUNCTION,
                    TorchOpClassType.NN_CORE_FUNCTION):
        attrs_map = self._get_module_attrs_map(node, torch_op_type,
                                               torch_op_attr.attrs)
        self._infer_attrs(attrs_map)
        statement = "{output} = {op_name}({attrs}) #{node_name}".format(
            output=output_str,
            op_name=torch_op_attr.op_name,
            attrs=self._to_map_str(attrs_map),
            node_name=node.name)
        return statement, output_str

    if op_class == TorchOpClassType.TENSOR:
        # The first input is the object the tensor method is called on.
        receiver = self._get_module_input(node)[0]
        attrs_map = self._get_module_attrs_map(node, torch_op_type,
                                               torch_op_attr.attrs)
        self._infer_attrs(attrs_map)
        # The receiver already stands in for "input"; do not pass it again.
        if 'input' in attrs_map:
            del attrs_map['input']
        statement = "{output} = {input}.{op_name}({attrs}) #{node_name}".format(
            output=output_str,
            input=receiver,
            op_name=torch_op_attr.op_name,
            attrs=self._to_map_str(attrs_map),
            node_name=node.name)
        return statement, output_str

    if node.op.type in MISC_OP_DISCR_MAP:
        statement = MISC_OP_DISCR_MAP[node.op.type](self, node, output_str)
        return statement, output_str

    raise RuntimeError(
        'op_class_type of op is unknown, please check the operation: {}'
        .format(node.op.type))
def _init_op_and_attrs_str(self, node: Node) -> str:
    """Resolve the constructor name and argument string for ``node``.

    Args:
        node: Graph node whose ``op.type`` selects the torch op.

    Returns:
        A ``(op_name, attrs_str)`` tuple (the ``str`` annotation is kept for
        interface compatibility). ``op_name`` is the name returned by
        ``py_utils.get_defined_quant_module`` when it is truthy, otherwise
        ``"<TorchSymbol.MODULE_PREFIX>.Module"``. In the fallback case the
        quoted op type is the leading argument in ``attrs_str``.
    """
    torch_op_type = py_utils.get_torch_op_type(node.op.type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)

    op_name = py_utils.get_defined_quant_module(torch_op_type)
    is_defined_op = bool(op_name)
    if not is_defined_op:
        # Fall back to the generic module factory.
        op_name = ".".join((TorchSymbol.MODULE_PREFIX, "Module"))

    # Only nn.Module-style ops take constructor attributes.
    ctor_args = ""
    if torch_op_attr.op_class_type == TorchOpClassType.NN_MODULE:
        attrs_map = self._get_module_attrs_map(node, torch_op_type,
                                               torch_op_attr.attrs)
        ctor_args = self._to_map_str(attrs_map)

    if not is_defined_op:
        quoted_type = f"'{node.op.type}'"
        ctor_args = f"{quoted_type},{ctor_args}" if ctor_args else quoted_type

    return op_name, ctor_args
def _get_forward_str(self, node: Node) -> str:
    """Generate the forward-pass statement that invokes the module of ``node``.

    * ``NN_MODULE``: ``<out> = self.<module_name>(<inputs>)``
    * ``NN_FUNCTION`` / ``TORCH_FUNCTION`` / ``TENSOR`` / ``PRIMITIVE``:
      ``<out> = self.<module_name>(<attrs as keywords>)``; TENSOR ops also
      receive the first node input as the ``input`` keyword.
    * ``UNKNOWN`` with ``node.op.type`` registered in ``MISC_OP_DISCR_MAP``:
      the registered callable produces the statement.

    Args:
        node: Graph node to generate the call for.

    Returns:
        A ``(forward_str, output_str)`` tuple (the ``str`` annotation is kept
        for interface compatibility).

    Raises:
        RuntimeError: If none of the cases above applies.
    """
    output_str = self._to_list_str(self._get_module_output(node))
    torch_op_type = py_utils.get_torch_op_type(node.op.type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
    op_class = torch_op_attr.op_class_type

    if op_class == TorchOpClassType.NN_MODULE:
        input_str = self._to_list_str(self._get_module_input(node))
        forward_str = "{output} = self.{module_name}({input})".format(
            output=output_str,
            module_name=self._get_module_name(node),
            input=input_str)
    elif op_class in (TorchOpClassType.NN_FUNCTION,
                      TorchOpClassType.TORCH_FUNCTION,
                      TorchOpClassType.TENSOR,
                      TorchOpClassType.PRIMITIVE):
        func_attrs = self._get_module_attrs_map(node, torch_op_type,
                                                torch_op_attr.attrs)
        # Return value is unused; presumably updates the map in place
        # (NOTE(review): confirm against _infer_attrs).
        self._infer_attrs(func_attrs)
        func_attrs_str = self._to_map_str(func_attrs)
        # Prepend the tensor operand only when the attribute map does not
        # already carry an "input" entry; otherwise the generated call would
        # repeat the `input=` keyword, which is a SyntaxError in the emitted
        # code. A falsy map keeps the original unconditional behaviour.
        if (op_class == TorchOpClassType.TENSOR
                and (not func_attrs or "input" not in func_attrs)):
            first_input = self._get_module_input(node)[0]
            func_attrs_str = f"input={first_input}, {func_attrs_str}"
        forward_str = "{output} = self.{module_name}({attrs})".format(
            output=output_str,
            module_name=self._get_module_name(node),
            attrs=func_attrs_str)
    elif (op_class == TorchOpClassType.UNKNOWN
          and node.op.type in MISC_OP_DISCR_MAP):
        forward_str = MISC_OP_DISCR_MAP[node.op.type](self, node, output_str)
    else:
        raise RuntimeError(
            'op_class_type of op is unknown, please check the operation: {}'
            .format(node.op.type))

    return forward_str, output_str
def Module(nndct_type, *args, **kwargs):
    """Create a module for an NNDCT op type.

    Args:
        nndct_type: NNDCT op type, translated with
            ``py_utils.get_torch_op_type``.
        *args: Positional arguments forwarded to ``creat_module``.
        **kwargs: Keyword arguments forwarded to ``creat_module``.

    Returns:
        Whatever ``creat_module`` returns for the resolved torch op type and
        its op attribute record.
    """
    op_type = py_utils.get_torch_op_type(nndct_type)
    op_attr = py_utils.get_torch_op_attr(op_type)
    return creat_module(op_type, op_attr, *args, **kwargs)
def _assert_valid_model(self, allow_reused_module):
    """Check that the model can be quantized as configured.

    Two checks are performed:

    1. Nodes that resolve to the same module name must share one qconfig.
       With ``allow_reused_module`` the later node adopts the qconfig already
       recorded for the module and a warning is logged; without it a
       ``ValueError`` is raised.
    2. Every node referenced by the quant config groups listed in
       ``self._qinfo_keys`` must be backed by a module of the expected class.

    Args:
        allow_reused_module: Whether one module may be called by more than
            one quantized node.

    Raises:
        ValueError: If a module is reused while reuse is not allowed, if a
            functional op was not replaced by its module counterpart, if a
            quantized node has no backing module, or if the backing module
            class does not match the torch op name.
    """
    # If two or more nodes point to the same module, they use the same qconfig.
    qconfig_by_module = {}
    for node in self._graph.nodes:
        owner = mod_util.module_name_from_node(node)
        if not (owner and node.name in self._node_to_qconfig):
            continue

        if owner in qconfig_by_module:
            if not allow_reused_module:
                raise ValueError(
                    ('Quantized module "{}" has been called multiple '
                     'times in forward pass. If you want to share quantized '
                     'parameters in multiple calls, call trainable_model with '
                     '"allow_reused_module=True"').format(owner))
            self._node_to_qconfig[node.name] = qconfig_by_module[owner]
            logging.warn(
                ('Reused module ({}) may lead to low accuracy of QAT, '
                 'make sure this is what you expect.').format(owner))
        qconfig_by_module[owner] = self._node_to_qconfig[node.name]

    # Make sure all quantizable operations are instances of torch.nn.Module.
    # Maps an op type to (what the user wrote, the module class to use instead).
    functional_replacements = {
        OpTypes.ADD: ('torch.add/+', functional.Add),
        OpTypes.CONCAT: ('torch.cat', functional.Cat),
        OpTypes.MAX: ('torch.max', functional.Max),
        OpTypes.PAD: ('torch.nn.functional.pad', functional.Pad),
        OpTypes.RELU: ('torch.nn.functional.relu', torch.nn.ReLU),
        OpTypes.SUM: ('torch.sum', functional.Sum),
    }

    for name, group in self._quant_config.items():
        if name not in self._qinfo_keys:
            continue

        for key in group:
            node_name, _ = self._quant_config[name][key]
            module = mod_util.get_module_by_node(self._model, node_name)
            node = self._graph.node(node_name)
            module_cls = type(module) if module else None

            if node.op.type in functional_replacements:
                op, target_cls = functional_replacements[node.op.type]
                if module_cls != target_cls:
                    raise ValueError(
                        ('Quantized operation({}) must be instance '
                         'of "torch.nn.Module", please replace {} with {}'
                        ).format(node.name, op, target_cls))

            # A quantized op must be implemented as a module.
            if not module:
                if node.op.type == OpTypes.INPUT:
                    raise ValueError(
                        ('Input is not quantized. Please use QuantStub/DeQuantStub to '
                         'define quantization scope.'))
                raise ValueError(
                    ('Can not quantize node "{}({})" as it is not a '
                     'torch.nn.Module object, please re-implement this operation '
                     'as a module.').format(node.name, node.op.type))

            torch_op_type = py_utils.get_torch_op_type(node.op.type)
            torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
            full_op_name = torch_op_attr.op_name
            if not full_op_name.startswith('torch'):
                logging.vlog(1, 'Non-torch op found: {}'.format(full_op_name))
                continue

            # Check if we get the correct module: unless the module class comes
            # from a 'pytorch_nndct' module, its name must equal the last
            # component of the torch op name.
            op_type_name = full_op_name.split('.')[-1]
            logging.vlog(
                1, '{}({}): {} vs. {}'.format(node.name, node.op.type,
                                              module_cls.__name__,
                                              full_op_name))
            from_nndct = module_cls.__module__.startswith('pytorch_nndct')
            if not from_nndct and module_cls.__name__ != op_type_name:
                raise ValueError(
                    ('{} is a quantized operation, please re-implement '
                     'your op as a nn.Module (Node: {})').format(
                         full_op_name, node_name))
def Module(nndct_type, *args, **kwargs):
    """Create a module for an NNDCT op type.

    Args:
        nndct_type: NNDCT op type, translated with
            ``py_utils.get_torch_op_type``.
        *args: Positional arguments forwarded to ``creat_module``.
        **kwargs: Keyword arguments forwarded to ``creat_module``.

    Returns:
        Whatever ``creat_module`` returns for the resolved torch op type and
        its op attribute record.
    """
    # The quantizer lookup result is not used below; the call and its
    # two-element unpacking are kept as in the original.
    # NOTE(review): confirm whether the call is still needed for side effects.
    _quant_mode, _ = maybe_get_quantizer()
    op_type = py_utils.get_torch_op_type(nndct_type)
    op_attr = py_utils.get_torch_op_attr(op_type)
    return creat_module(op_type, op_attr, *args, **kwargs)
def _get_forward_str(self, node: Node) -> str:
    """Generate the forward-pass statement that invokes the module of ``node``.

    * ``NN_MODULE``: the module is called with the node inputs.
    * ``NN_FUNCTION`` / ``TORCH_FUNCTION`` / ``TENSOR`` / ``PRIMITIVE`` /
      ``NN_CORE_FUNCTION``: the module is called with the attribute map as
      keyword arguments; TENSOR ops get the first node input prepended as
      ``input=`` unless the map already has an ``input`` entry.
    * ``UNKNOWN`` with ``node.op.type`` registered in ``MISC_OP_DISCR_MAP``:
      the registered callable produces the statement.
    * ``TORCH_SCRIPT_BUILTIN_FUNCTION`` / ``MATH_BUILTIN_FUNCTION`` /
      ``GLOBAL_BUILTIN_FUNCTION`` / ``CUSTOM_FUNCTION``: the module is called
      with the attribute values as positional arguments.

    In the attribute-based cases an empty attribute map falls back to calling
    the module with the node inputs.

    Args:
        node: Graph node to generate the call for.

    Returns:
        A ``(forward_str, output_str)`` tuple (the ``str`` annotation is kept
        for interface compatibility).

    Raises:
        RuntimeError: If none of the cases above applies.
    """
    output_str = self._to_list_str(self._get_module_output(node))
    torch_op_type = py_utils.get_torch_op_type(node.op.type)
    torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
    op_class = torch_op_attr.op_class_type

    def call_on_inputs():
        # `<out> = self.<module>(<node inputs>)`
        inputs_str = self._to_list_str(self._get_module_input(node))
        return "{output} = self.{module_name}({input})".format(
            output=output_str,
            module_name=self._get_module_name(node),
            input=inputs_str)

    def call_with(arguments):
        # `<out> = self.<module>(<already rendered arguments>)`
        return "{output} = self.{module_name}({attrs})".format(
            output=output_str,
            module_name=self._get_module_name(node),
            attrs=arguments)

    if op_class == TorchOpClassType.NN_MODULE:
        forward_str = call_on_inputs()
    elif op_class in (TorchOpClassType.NN_FUNCTION,
                      TorchOpClassType.TORCH_FUNCTION,
                      TorchOpClassType.TENSOR,
                      TorchOpClassType.PRIMITIVE,
                      TorchOpClassType.NN_CORE_FUNCTION):
        attrs_map = self._get_module_attrs_map(node, torch_op_type,
                                               torch_op_attr.attrs)
        if not attrs_map:
            forward_str = call_on_inputs()
        else:
            self._infer_attrs(attrs_map)
            kwargs_str = self._to_map_str(attrs_map)
            # Avoid emitting `input=` twice when the map already names it.
            if (op_class == TorchOpClassType.TENSOR
                    and "input" not in attrs_map):
                receiver = self._get_module_input(node)[0]
                kwargs_str = f"input={receiver}, {kwargs_str}"
            forward_str = call_with(kwargs_str)
    elif (op_class == TorchOpClassType.UNKNOWN
          and node.op.type in MISC_OP_DISCR_MAP):
        forward_str = MISC_OP_DISCR_MAP[node.op.type](self, node, output_str)
    elif op_class in (TorchOpClassType.TORCH_SCRIPT_BUILTIN_FUNCTION,
                      TorchOpClassType.MATH_BUILTIN_FUNCTION,
                      TorchOpClassType.GLOBAL_BUILTIN_FUNCTION,
                      TorchOpClassType.CUSTOM_FUNCTION):
        attrs_map = self._get_module_attrs_map(node, torch_op_type,
                                               torch_op_attr.attrs)
        if not attrs_map:
            forward_str = call_on_inputs()
        else:
            self._infer_attrs(attrs_map)
            # These callables take their attributes positionally.
            positional = list(attrs_map.values())
            forward_str = call_with(self._to_list_str(positional))
    else:
        raise RuntimeError(
            f'op_class_type of op ({torch_op_attr.op_class_type.value}) is unknown, please check the operation: {node.op.type}.'
        )

    return forward_str, output_str