def insert_quantizer(model, node_to_qconfig):
  for node, qconfig in node_to_qconfig.items():
    if not qconfig.input and not qconfig.output:
      continue

    module_name = mod_util.module_name_from_node(node)
    if not module_name:
      raise ValueError(
          ('Can not find module for node "{}". '
           'Only module objects can be quantized; '
           'please re-implement this operation as a module.').format(node))
    module = mod_util.get_module(model, module_name)

    if qconfig.input:
      # Reserved for multiple-input support; currently the index is always 0.
      quantize_input(module, 0, qconfig.input)
    if qconfig.output:
      quantize_output(module, qconfig.output)
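# A minimal sketch (an assumption, not the library's actual mechanism) of
# what quantize_input/quantize_output conceptually do: attach quantizers at
# a module's input and output boundaries. It is modeled here with plain
# PyTorch hooks and a hypothetical _fake_quant placeholder; the real
# pytorch_nndct quantizers differ.
def _fake_quant(t):
  # Placeholder quantizer: snap values to a 2^-4 grid (illustrative only).
  return torch.round(t * 16) / 16

def _quantize_input_sketch(module, index, quantizer=_fake_quant):
  # Quantize the `index`-th positional input before the module runs.
  def pre_hook(mod, inputs):
    inputs = list(inputs)
    inputs[index] = quantizer(inputs[index])
    return tuple(inputs)

  module.register_forward_pre_hook(pre_hook)

def _quantize_output_sketch(module, quantizer=_fake_quant):
  # Quantize the module's output after it runs.
  def hook(mod, inputs, output):
    return quantizer(output)

  module.register_forward_hook(hook)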
def _assert_valid_model(self, allow_reused_module):
  # If two or more nodes point to the same module, then we let them
  # share the same qconfig.
  module_to_qconfig = {}
  for node in self._graph.nodes:
    module_name = mod_util.module_name_from_node(node)
    if not module_name or node.name not in self._node_to_qconfig:
      continue

    if module_name in module_to_qconfig:
      if allow_reused_module:
        self._node_to_qconfig[node.name] = module_to_qconfig[module_name]
        logging.warn(
            ('Reused module ({}) may lead to low accuracy of QAT, '
             'make sure this is what you expect.').format(module_name))
      else:
        raise ValueError(
            ('Quantized module "{}" has been called multiple '
             'times in forward pass. If you want to share quantized '
             'parameters in multiple calls, call trainable_model with '
             '"allow_reused_module=True".').format(module_name))
    module_to_qconfig[module_name] = self._node_to_qconfig[node.name]

  # Make sure all quantizable operations are instances of torch.nn.Module.
  replacement_map = {
      OpTypes.ADD: ('torch.add/+', functional.Add),
      OpTypes.CONCAT: ('torch.cat', functional.Cat),
      OpTypes.MAX: ('torch.max', functional.Max),
      OpTypes.PAD: ('torch.nn.functional.pad', functional.Pad),
      OpTypes.RELU: ('torch.nn.functional.relu', torch.nn.ReLU),
      OpTypes.SUM: ('torch.sum', functional.Sum),
  }

  for name, group in self._quant_config.items():
    if name not in self._qinfo_keys:
      continue

    for key in group:
      node_name, _ = self._quant_config[name][key]
      module = mod_util.get_module_by_node(self._model, node_name)
      node = self._graph.node(node_name)
      module_cls = type(module) if module else None

      if node.op.type in replacement_map:
        op, target_cls = replacement_map[node.op.type]
        if module_cls != target_cls:
          raise ValueError(
              ('Quantized operation ({}) must be an instance of '
               '"torch.nn.Module"; please replace {} with {}.').format(
                   node.name, op, target_cls))

      # A quantized op must be implemented as a module.
      if not module:
        if node.op.type == OpTypes.INPUT:
          raise ValueError(
              'Input is not quantized. Please use QuantStub/DeQuantStub '
              'to define the quantization scope.')
        else:
          raise ValueError(
              ('Can not quantize node "{}({})" as it is not a '
               'torch.nn.Module object; please re-implement this '
               'operation as a module.').format(node.name, node.op.type))

      torch_op_type = py_utils.get_torch_op_type(node.op.type)
      torch_op_attr = py_utils.get_torch_op_attr(torch_op_type)
      if not torch_op_attr.op_name.startswith('torch'):
        logging.vlog(1,
                     'Non-torch op found: {}'.format(torch_op_attr.op_name))
        continue

      # Check that the node is backed by the module class we expect.
      op_type_name = torch_op_attr.op_name.split('.')[-1]
      logging.vlog(
          1, '{}({}): {} vs. {}'.format(node.name, node.op.type,
                                        module_cls.__name__,
                                        torch_op_attr.op_name))
      if not module_cls.__module__.startswith(
          'pytorch_nndct') and module_cls.__name__ != op_type_name:
        raise ValueError(
            ('{} is a quantized operation, please re-implement '
             'your op as a nn.Module (Node: {})').format(
                 torch_op_attr.op_name, node_name))
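# Self-contained illustration of the reuse check above: one ReLU instance is
# called twice in forward(), so two graph nodes resolve to the same module
# name "relu". With allow_reused_module=False such a model is rejected; with
# allow_reused_module=True both nodes share a single qconfig (and therefore
# quantization parameters), which is why a warning is emitted. The class
# name is hypothetical and not part of the library.
class _ReusedModuleExample(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.conv1 = torch.nn.Conv2d(3, 8, 3)
    self.conv2 = torch.nn.Conv2d(8, 8, 3)
    self.relu = torch.nn.ReLU()  # One instance, invoked twice below.

  def forward(self, x):
    x = self.relu(self.conv1(x))     # First call: node A maps to "relu".
    return self.relu(self.conv2(x))  # Second call: node B maps to "relu".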
def _topo_node_name(node):
  module_name = mod_util.module_name_from_node(node)
  node_name = node if isinstance(node, str) else node.name
  # Use the node name for non-module nodes so that
  # we can have a complete topology.
  return module_name if module_name else node_name
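# Usage note (illustrative names; real graph node objects are
# library-internal):
#   _topo_node_name(conv_node)   # -> e.g. 'features.conv1' (module-backed)
#   _topo_node_name('input_0')   # -> 'input_0' (falls back to node name)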