def checkGraphModeFxOp(self, model, inputs, quantized_node,
                       quant_type=QuantType.STATIC, debug=False):
    """ Quantizes the model with graph mode quantization on fx and checks
    that the quantized model contains the quantized_node

    Args:
        model: floating point torch.nn.Module
        inputs: a tuple of positional sample inputs for the model
        quantized_node: a tuple of 2 elements, the first element is
            the op type for the GraphModule node, the second element is
            the target function for call_function and
            the type of the module for call_module
    """
    # TODO: make img_data a single example instead of a list
    if type(inputs) == list:
        inputs = inputs[0]

    if quant_type == QuantType.QAT:
        model.train()
    else:
        model.eval()

    original = symbolic_trace(model)
    fused = fuse(original)

    quantizer = Quantizer()
    # TODO: uncomment after we make per channel observer work in the flow
    # qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
    qconfig_dict = {'': default_qconfig}

    if quant_type == QuantType.DYNAMIC:
        prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
    else:
        prepared = quantizer.prepare(fused, qconfig_dict)

    # calibrate, then convert with and without debug mode
    prepared(*inputs)
    qgraph = quantizer.convert(prepared)
    qgraph_debug = quantizer.convert(prepared, debug=True)

    result = qgraph(*inputs)
    result_debug = qgraph_debug(*inputs)
    self.assertEqual(
        (result - result_debug).abs().max(), 0,
        'Expecting debug and non-debug option to produce identical result')

    if debug:
        print()
        print('quant type:', quant_type)
        print('original graph module:', type(model))
        self.printGraphModule(original)
        print()
        print('quantized graph module:', type(qgraph))
        self.printGraphModule(qgraph)
        print()
    self.checkGraphModuleHasNode(qgraph, quantized_node)
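# A minimal usage sketch for the helper above, not part of the original test
# suite: the ConvModel class and the expected quantized node are illustrative
# assumptions showing how a test would typically call checkGraphModeFxOp.
def _example_check_graph_mode_fx_op(self):
    class ConvModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    # a tuple of positional inputs, as expected by checkGraphModeFxOp
    data = (torch.randn(1, 3, 10, 10),)
    # static quantization should swap nn.Conv2d for the quantized module,
    # so we expect a call_module node of that type in the converted graph
    self.checkGraphModeFxOp(
        ConvModel(), data,
        quantized_node=('call_module', torch.nn.quantized.Conv2d),
        quant_type=QuantType.STATIC)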
def _test_model_impl(
        self, mode, name, model, eager_quantizable_model,
        check_with_eager=True,
        diff_of_quant=None,
        diff_from_eager=None):
    if diff_of_quant is None or diff_from_eager is None:
        diff_of_quant = {}
        diff_from_eager = {}

    if mode not in diff_of_quant or mode not in diff_from_eager:
        diff_of_quant[mode] = {}
        diff_from_eager[mode] = {}

    input_tensor = torch.rand(1, 3, 224, 224)
    input_tensor_inception = torch.rand(1, 3, 299, 299)
    output_value = torch.randint(0, 1, (1,))

    # print('quantizing:', name, ' mode:', mode)
    if name == 'inception_v3':
        input_value = input_tensor_inception
    else:
        input_value = input_tensor

    qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
    qconfig_dict = {'': qconfig}
    graph_module = symbolic_trace(model)
    # print('graph module:', graph_module.src)
    script = torch.jit.script(graph_module)

    # make sure graph module and script module are both runnable
    original_out = graph_module(input_value)
    is_not_tuple_out = not isinstance(original_out, tuple)
    script_out = script(input_value)
    self.assertEqual(
        (original_out - script_out).abs().max(), 0,
        'Result of original graph module and script module does not match')

    # set to train just before quantization
    if mode != 'static':
        model.train()

    graph_module = fuse(graph_module)
    quantizer = Quantizer()
    prepared = quantizer.prepare(graph_module, qconfig_dict)

    if mode == 'ddp':
        mp.spawn(run_ddp,
                 args=(world_size, prepared),
                 nprocs=world_size,
                 join=True)
    elif mode == 'qat':
        assert prepared.training, 'prepared must be in training mode for qat'
        optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        train_one_epoch(prepared, criterion, optimizer,
                        [(input_value, output_value)], torch.device('cpu'), 1)
    else:
        for i in range(10):
            prepared(input_value)

    # print('after observation root:', prepared.root)

    qgraph = quantizer.convert(prepared)
    # print('after quantization root:', qgraph.root)
    # print('after quantization code:', qgraph.src)
    qgraph.eval()
    qgraph_script = torch.jit.script(qgraph)
    # print('quantized and scripted:', qgraph_script.graph)

    qgraph_out = qgraph(input_value)
    qgraph_script_out = qgraph_script(input_value)

    if is_not_tuple_out:
        diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
        assert torch.allclose(qgraph_out, qgraph_script_out), \
            'result of graph module and scripted graph module does not match'
    else:
        print('tuple output')

    if eager_quantizable_model is not None:
        # comparing to eager mode quantization
        qeager = eager_quantizable_model
        ref_out = qeager(input_value)
        qeager.qconfig = qconfig
        if mode == 'static':
            qeager.fuse_model()
            prepare(qeager, inplace=True)
        else:
            qeager.train()
            qeager.fuse_model()
            prepare_qat(qeager, inplace=True)

        # calibration
        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, qeager),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            assert qeager.training, 'qeager should be in training mode for qat'
            optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(qeager, criterion, optimizer,
                            [(input_value, output_value)],
                            torch.device('cpu'), 1)
        else:
            for i in range(10):
                qeager(input_value)

        # print('ref after observation:', qeager)

        convert(qeager, inplace=True)
        qeager.eval()

        # print('ref after quantization:', qeager)
        qeager_out = qeager(input_value)
        qeager_script = torch.jit.script(qeager)
        qscript_out = qeager_script(input_value)

        if is_not_tuple_out:
            diff_from_eager[mode][name] = \
                (qeager_out - qgraph_out).abs().max()
            if check_with_eager:
                self.assertEqual(
                    diff_from_eager[mode][name], 0,
                    'Result of graph mode quantization and '
                    'eager mode quantization on model: ' + name +
                    ' should match. Mode: ' + mode +
                    ' diff:' + str(diff_from_eager[mode][name]))
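# A hedged sketch, not from the original file, of how _test_model_impl might
# be driven: the choice of torchvision's resnet18 and its eager-mode
# quantizable counterpart (which provides fuse_model()) is an illustrative
# assumption.
def _example_test_model_impl(self):
    from torchvision import models
    model = models.resnet18(pretrained=False).float().eval()
    # torchvision ships an eager-mode quantizable variant of the same model
    eager_model = models.quantization.resnet18(
        pretrained=False, quantize=False).float().eval()
    self._test_model_impl('static', 'resnet18', model, eager_model)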
def checkGraphModeFxOp(self, model, inputs, quant_type,
                       expected_node=None,
                       expected_node_occurrence=None,
                       expected_node_list=None, debug=False):
    """ Quantizes the model with graph mode quantization on fx and checks
    that the quantized model contains the expected nodes

    Args:
        model: floating point torch.nn.Module
        inputs: a tuple of positional sample inputs for the model
        expected_node: NodeSpec
            e.g. NodeSpec.call_function(torch.quantize_per_tensor)
        expected_node_occurrence: a dict from NodeSpec to
            expected number of occurrences (int)
            e.g. {NodeSpec.call_function(torch.quantize_per_tensor): 1,
                  NodeSpec.call_method('dequantize'): 1}
        expected_node_list: a list of NodeSpec, used to check the order
            of the occurrence of Node
            e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                  NodeSpec.call_module(nnq.Conv2d),
                  NodeSpec.call_function(F.hardtanh_),
                  NodeSpec.call_method('dequantize')]
    """
    # TODO: make img_data a single example instead of a list
    if type(inputs) == list:
        inputs = inputs[0]

    if quant_type == QuantType.QAT:
        model.train()
    else:
        model.eval()

    original = symbolic_trace(model)
    fused = fuse(original)

    quantizer = Quantizer()
    qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
    if quant_type == QuantType.DYNAMIC:
        prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
    else:
        prepared = quantizer.prepare(fused, qconfig_dict)

    # calibrate, then convert with and without debug mode
    prepared(*inputs)
    qgraph = quantizer.convert(prepared)
    qgraph_debug = quantizer.convert(prepared, debug=True)

    result = qgraph(*inputs)
    result_debug = qgraph_debug(*inputs)
    self.assertEqual(
        (result - result_debug).abs().max(), 0,
        'Expecting debug and non-debug option to produce identical result')

    if debug:
        print()
        print('quant type:', quant_type)
        print('original graph module:', type(model))
        self.printGraphModule(original)
        print()
        print('quantized graph module:', type(qgraph))
        self.printGraphModule(qgraph)
        print()
    self.checkGraphModuleNodes(
        qgraph, expected_node, expected_node_occurrence, expected_node_list)
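# Illustrative usage of the NodeSpec-based helper above (hypothetical model,
# not from the original file), assuming NodeSpec and the nnq alias for
# torch.nn.quantized are in scope as in the docstring examples: it checks the
# presence, count, and order of the expected quantize/dequantize nodes.
def _example_check_graph_module_nodes(self):
    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    data = (torch.rand(8, 5),)
    self.checkGraphModeFxOp(
        LinearModel(), data, QuantType.STATIC,
        expected_node=NodeSpec.call_module(nnq.Linear),
        expected_node_occurrence={
            NodeSpec.call_function(torch.quantize_per_tensor): 1,
            NodeSpec.call_method('dequantize'): 1,
        },
        expected_node_list=[
            NodeSpec.call_function(torch.quantize_per_tensor),
            NodeSpec.call_module(nnq.Linear),
            NodeSpec.call_method('dequantize'),
        ])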
def test_functional(self):
    """ Test quantizing functional conv and linear
    """
    class Conv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.stride = (1, 1)
            self.padding = (0, 0)
            self.dilation = (1, 1)
            self.groups = 1

        def forward(self, x, weight):
            return F.conv2d(x, weight, None, self.stride,
                            self.padding, self.dilation, self.groups)

    conv_input = torch.rand(1, 3, 224, 224)
    conv_weight = torch.rand(3, 3, 3, 3)

    class Linear(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, weight):
            return F.linear(x, weight)

    linear_input = torch.rand(8, 5)
    linear_weight = torch.rand(10, 5)

    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)

    # (is_dynamic, module class, inputs, expected quantized node)
    tests = [
        (False, Conv, (conv_input, conv_weight),
         ('call_function', torch.ops.quantized.conv2d)),
        (True, Linear, (linear_input, linear_weight),
         ('call_function', torch.ops.quantized.linear_dynamic)),
        (False, Linear, (linear_input, linear_weight),
         ('call_function', torch.ops.quantized.linear)),
        (True, LinearModule, (linear_module_input,),
         ('call_module', torch.nn.quantized.dynamic.Linear)),
        (False, LinearModule, (linear_module_input,),
         ('call_module', torch.nn.quantized.Linear)),
    ]

    for is_dynamic, M, inputs, quantized_node in tests:
        m = M().eval()
        qconfig = default_qconfig

        graph = symbolic_trace(m)
        script = torch.jit.script(graph)

        # make sure eager, traced and scripted results all match
        a = m(*inputs)
        b = graph(*inputs)
        c = script(*inputs)
        assert (a - b).abs().max() == 0
        assert (a - c).abs().max() == 0
        assert torch.allclose(a, b)
        assert torch.allclose(a, c)

        graph = fuse(graph)

        quantizer = Quantizer()
        qconfig_dict = {'': qconfig}
        if is_dynamic:
            prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
        else:
            prepared = quantizer.prepare(graph, qconfig_dict)
        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)
        qgraph.eval()
        qgraph_debug.eval()
        qgraph_script = torch.jit.script(qgraph)

        d = qgraph(*inputs)
        d_debug = qgraph_debug(*inputs)
        e = qgraph_script(*inputs)
        e_debug = qgraph_debug(*inputs)

        # check that the expected quantized node appears in the graph
        found = False
        modules = dict(qgraph.root.named_modules())
        for node in qgraph.graph.nodes:
            if node.op == 'call_function':
                found = found or node.op == quantized_node[0] and \
                    node.target == quantized_node[1]
            elif node.op == 'call_module':
                found = found or node.op == quantized_node[0] and \
                    type(modules[node.target]) == quantized_node[1]
        assert found, \
            'Expected to find quantized node:' + str(quantized_node)
        # assert (a - d).abs().max() < 2
        assert torch.allclose(d, e)
        assert (d - d_debug).abs().max() == 0
        assert (e - e_debug).abs().max() == 0
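# For comparison (hypothetical, not in the original file): the dynamic
# functional linear case above could also be expressed with the
# NodeSpec-based checkGraphModeFxOp helper, assuming NodeSpec is in scope.
def _example_functional_linear_dynamic(self):
    class Linear(torch.nn.Module):
        def forward(self, x, weight):
            return F.linear(x, weight)

    data = (torch.rand(8, 5), torch.rand(10, 5))
    # dynamic quantization should lower F.linear to the dynamic quantized op
    self.checkGraphModeFxOp(
        Linear(), data, QuantType.DYNAMIC,
        expected_node=NodeSpec.call_function(
            torch.ops.quantized.linear_dynamic))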