def checkGraphModeFxOp(self, model, inputs, quantized_node,
                       quant_type=QuantType.STATIC, debug=False):
    """ Quantizes model with graph mode quantization on fx and checks if the
    quantized model contains the quantized_node

    Args:
        model: floating point torch.nn.Module
        inputs: a tuple of positional sample inputs for the model
        quantized_node: a tuple of 2 elements, first element is
            the op type for the GraphModule node, second element is
            the target function for call_function and
            the type of the module for call_module
    """
    # TODO: make img_data a single example instead of a list
    if type(inputs) == list:
        inputs = inputs[0]

    if quant_type == QuantType.QAT:
        model.train()
    else:
        model.eval()

    original = symbolic_trace(model)
    fused = fuse(original)

    quantizer = Quantizer()
    # TODO: uncomment after we make the per channel observer work in the flow
    # qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
    qconfig_dict = {'': default_qconfig}
    if quant_type == QuantType.DYNAMIC:
        prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
    else:
        prepared = quantizer.prepare(fused, qconfig_dict)

    prepared(*inputs)
    qgraph = quantizer.convert(prepared)
    qgraph_debug = quantizer.convert(prepared, debug=True)

    result = qgraph(*inputs)
    result_debug = qgraph_debug(*inputs)
    self.assertEqual(
        (result - result_debug).abs().max(), 0,
        'Expecting debug and non-debug option to produce identical result')

    if debug:
        print()
        print('quant type:', quant_type)
        print('original graph module:', type(model))
        self.printGraphModule(original)
        print()
        print('quantized graph module:', type(qgraph))
        self.printGraphModule(qgraph)
        print()
    self.checkGraphModuleHasNode(qgraph, quantized_node)
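# A minimal usage sketch for checkGraphModeFxOp above; LinearModel and the
# expected node are illustrative assumptions, not part of the test suite.
def test_quantized_linear_example(self):
    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    # a single positional input, wrapped in a tuple for *inputs unpacking
    data = (torch.rand(8, 5),)
    # after static quantization the float linear should be swapped for a
    # call_module node targeting the quantized module
    self.checkGraphModeFxOp(
        LinearModel(), data, ('call_module', torch.nn.quantized.Linear),
        quant_type=QuantType.STATIC)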
def _test_model_impl(
        self, mode, name, model, eager_quantizable_model,
        check_with_eager=True,
        diff_of_quant=None,
        diff_from_eager=None):
    if diff_of_quant is None or diff_from_eager is None:
        diff_of_quant = {}
        diff_from_eager = {}

    if mode not in diff_of_quant or mode not in diff_from_eager:
        diff_of_quant[mode] = {}
        diff_from_eager[mode] = {}

    input_tensor = torch.rand(1, 3, 224, 224)
    input_tensor_inception = torch.rand(1, 3, 299, 299)
    output_value = torch.randint(0, 1, (1,))

    # print('quantizing:', name, ' mode:', mode)
    if name == 'inception_v3':
        input_value = input_tensor_inception
    else:
        input_value = input_tensor

    qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
    qconfig_dict = {'': qconfig}
    graph_module = symbolic_trace(model)
    # print('graph module:', graph_module.src)
    script = torch.jit.script(graph_module)

    # make sure graph module and script module are both runnable
    original_out = graph_module(input_value)
    is_not_tuple_out = not isinstance(original_out, tuple)
    script_out = script(input_value)
    self.assertEqual(
        (original_out - script_out).abs().max(), 0,
        'Result of original graph module and script module does not match')

    # set to train just before quantization
    if mode != 'static':
        model.train()

    graph_module = fuse(graph_module)
    quantizer = Quantizer()
    prepared = quantizer.prepare(graph_module, qconfig_dict)

    if mode == 'ddp':
        mp.spawn(run_ddp,
                 args=(world_size, prepared),
                 nprocs=world_size,
                 join=True)
    elif mode == 'qat':
        assert prepared.training, 'prepared must be in training mode for qat'
        optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        train_one_epoch(prepared, criterion, optimizer,
                        [(input_value, output_value)], torch.device('cpu'), 1)
    else:
        for i in range(10):
            prepared(input_value)

    # print('after observation root:', prepared.root)

    qgraph = quantizer.convert(prepared)
    # print('after quantization root:', qgraph.root)
    # print('after quantization code:', qgraph.src)
    qgraph.eval()
    qgraph_script = torch.jit.script(qgraph)
    # print('quantized and scripted:', qgraph_script.graph)

    qgraph_out = qgraph(input_value)
    qgraph_script_out = qgraph_script(input_value)

    if is_not_tuple_out:
        diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
        assert torch.allclose(qgraph_out, qgraph_script_out), \
            'graph, scripted graph'
    else:
        print('tuple output')

    if eager_quantizable_model is not None:
        # comparing to eager mode quantization
        qeager = eager_quantizable_model
        ref_out = qeager(input_value)
        qeager.qconfig = qconfig
        if mode == 'static':
            qeager.fuse_model()
            prepare(qeager, inplace=True)
        else:
            qeager.train()
            qeager.fuse_model()
            prepare_qat(qeager, inplace=True)

        # calibration
        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, qeager),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            assert qeager.training, 'qeager should be in training mode for qat'
            optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(qeager, criterion, optimizer,
                            [(input_value, output_value)],
                            torch.device('cpu'), 1)
        else:
            for i in range(10):
                qeager(input_value)

        # print('ref after observation:', qeager)

        convert(qeager, inplace=True)
        qeager.eval()

        # print('ref after quantization:', qeager)
        qeager_out = qeager(input_value)
        qeager_script = torch.jit.script(qeager)
        qscript_out = qeager_script(input_value)

        if is_not_tuple_out:
            diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max()
            if check_with_eager:
                self.assertEqual(
                    diff_from_eager[mode][name], 0,
                    'Result of graph mode quantization and ' +
                    'eager mode quantization on model: ' + name +
                    ' should match. Mode: ' + mode +
                    ' diff:' + str(diff_from_eager[mode][name]))
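# A usage sketch for _test_model_impl, assuming torchvision is available;
# the quantizable resnet18 factory below is an assumption for illustration,
# and pretrained weights are deliberately not downloaded.
def test_resnet18_static_example(self):
    from torchvision import models
    model = models.resnet18(pretrained=False).eval()
    # eager-quantizable variant with fuse_model() support, used as reference
    eager_quantizable = models.quantization.resnet18(
        pretrained=False, quantize=False).eval()
    self._test_model_impl(
        'static', 'resnet18', model, eager_quantizable)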
def test_general_value_ops(self):
    """ A test that checks correct patterns are produced for
    all supported general value ops like aten::avg_pool2d
    without actually checking for execution of these ops
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.avg_pool1d = torch.nn.AvgPool1d(3)
            self.avg_pool2d = torch.nn.AvgPool2d(3)
            self.avg_pool3d = torch.nn.AvgPool3d(3)
            self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
            self.leaky_relu = torch.nn.LeakyReLU()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.sigmoid = torch.nn.Sigmoid()
            self.tanh = torch.nn.Tanh()

        def forward(self, x):
            x = self.conv(x)
            x = self.avg_pool1d(x)
            x = self.avg_pool2d(x)
            x = self.avg_pool3d(x)
            x = self.adaptive_avg_pool1d(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.adaptive_avg_pool3d(x)
            x = F.avg_pool1d(x, 3)
            x = F.avg_pool2d(x, 3)
            x = F.avg_pool3d(x, 3)
            x = F.adaptive_avg_pool1d(x, (1))
            x = F.adaptive_avg_pool2d(x, (1, 1))
            x = F.adaptive_avg_pool3d(x, (1, 1, 1))
            x = torch.mean(x)
            x = torch.mean(x, [2, 3], False)
            x = x.mean()
            x = x.mean([2, 3], True)
            x = F.interpolate(x, 4, mode='nearest')
            x = F.interpolate(x, 4, mode='linear')
            x = self.leaky_relu(x)
            x = F.leaky_relu(x)
            x = F.leaky_relu(x, inplace=True)
            x = x.leaky_relu()
            x.leaky_relu_()
            x = self.hardsigmoid(x)
            x = F.hardsigmoid(x)
            x = F.hardsigmoid(x, inplace=True)
            x = x.hardsigmoid()
            x.hardsigmoid_()
            x = self.sigmoid(x)
            x = torch.sigmoid(x)
            # F.sigmoid is deprecated
            x = x.sigmoid()
            x.sigmoid_()
            x = self.tanh(x)
            # F.tanh is deprecated
            x = torch.tanh(x)
            x = x.tanh()
            x.tanh_()
            x = self.conv(x)
            return x

    # This model is not executable since we just put all ops
    # in the same forward
    m = M()
    original = symbolic_trace(m)
    # nothing to fuse so skipping the fuse step
    quantizer = Quantizer()
    qconfig_dict = {'': default_qconfig}
    prepared = quantizer.prepare(original, qconfig_dict)
    # not runnable
    quantized = quantizer.convert(prepared)

    # This checks that the dequantize from the output of first conv
    # is being propagated to the end, so that we don't insert extra
    # observers
    # check exact counts of quantize and dequantize
    count_check = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method('dequantize'): 1
    }
    order_check = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]
    self.checkGraphModuleNodes(
        quantized,
        expected_node_occurrence=count_check,
        expected_node_list=order_check)
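# For reference, the converted forward above is expected to look roughly like
# the following sketch (illustrative, not verified generated code): the input
# is quantized once, every value op runs on the quantized tensor, and a single
# dequantize is emitted at the output.
#
#   def forward(self, x):
#       x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
#       x = self.conv(x)       # nnq.Conv2d
#       ...                    # value ops stay in the quantized domain
#       x = self.conv(x)       # nnq.Conv2d
#       return x.dequantize()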
def test_general_shape_ops(self):
    """ A test that checks dequantize will be swapped for
    all supported general shape ops like aten::flatten
    without actually checking for execution of these ops
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
            self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
            self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
            self.dropout = torch.nn.Dropout()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            # add_scalar
            x = x + 3
            # mul_scalar
            x = x * 3
            # add_scalar_out
            x += 3
            # mul_scalar_out
            x *= 3
            # add_scalar_relu
            x = x + 3
            x = F.relu(x)
            # add_scalar_relu_out
            x += 3
            x = F.relu(x)
            # mul_scalar_relu
            x = x * 3
            x = F.relu(x)
            # mul_scalar_relu_out
            x *= 3
            x = F.relu(x)
            x = self.maxpool1d(x)
            x = self.maxpool2d(x)
            x = self.maxpool3d(x)
            x = torch.flatten(x)
            x = torch.max(x)
            x = torch.min(x)
            x = x.reshape([-1])
            x = x.resize_(1, 1, x.numel())
            x = x.view(-1)
            # prim::ListConstruct
            xs = [x, x]
            # prim::ListUnpack
            x, y = xs
            # prim::TupleConstruct
            xs = (x, x)
            # prim::TupleUnpack
            x, y = xs
            x = x.transpose(1, 2)
            x = x.contiguous()
            x, y = torch.chunk(x, 2)
            x = F.dropout(x)
            x = self.dropout(x)
            x, _ = torch.sort(x)
            x = x.permute(0, 2, 3, 1)
            x = x.repeat_interleave(3, 1)
            x = torch.repeat_interleave(x, 3, 1)
            x = self.relu(x)
            x = F.relu(x)
            x = F.relu(x, inplace=True)
            x = x.relu()
            x.relu_()
            x = x.squeeze(0)
            x.squeeze_(0)
            x = torch.squeeze(x, 0)
            x = x.unsqueeze(0)
            x.unsqueeze_(0)
            x = torch.unsqueeze(x, 0)
            x = x.detach()
            x.detach_()
            x = x.repeat(4, 2)
            y = []
            y.append(x)
            z = torch.stack(y, 0)
            z = [z, z]
            x, _ = z
            x = self.conv2(x)
            return x

    data = torch.rand(1, 3, 10, 10)
    # This model is not executable since we just put all ops
    # in the same forward
    m = M()
    original = symbolic_trace(m)
    # nothing to fuse so skipping the fuse step
    quantizer = Quantizer()
    qconfig_dict = {'': default_qconfig}
    prepared = quantizer.prepare(original, qconfig_dict)
    # not runnable
    quantized = quantizer.convert(prepared)

    # This checks that the dequantize from the output of the first conv
    # is propagated to the end, so that we don't insert extra observers:
    # both convs are quantized and there is only one quantize_per_tensor
    # for the input
    # check exact counts of quantize and dequantize
    count_check = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method('dequantize'): 1
    }
    order_check = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]
    self.checkGraphModuleNodes(
        quantized,
        expected_node_occurrence=count_check,
        expected_node_list=order_check)
def checkGraphModeFxOp(self, model, inputs, quant_type,
                       expected_node=None,
                       expected_node_occurrence=None,
                       expected_node_list=None, debug=False):
    """ Quantizes model with graph mode quantization on fx and checks if the
    quantized model contains the expected nodes

    Args:
        model: floating point torch.nn.Module
        inputs: a tuple of positional sample inputs for the model
        expected_node: NodeSpec
            e.g. NodeSpec.call_function(torch.quantize_per_tensor)
        expected_node_occurrence: a dict from NodeSpec to
            expected number of occurrences (int)
            e.g. {NodeSpec.call_function(torch.quantize_per_tensor): 1,
                  NodeSpec.call_method('dequantize'): 1}
        expected_node_list: a list of NodeSpec, used to check the order
            of the occurrence of Node
            e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                  NodeSpec.call_module(nnq.Conv2d),
                  NodeSpec.call_function(F.hardtanh_),
                  NodeSpec.call_method('dequantize')]
    """
    # TODO: make img_data a single example instead of a list
    if type(inputs) == list:
        inputs = inputs[0]

    if quant_type == QuantType.QAT:
        model.train()
    else:
        model.eval()

    original = symbolic_trace(model)
    fused = fuse(original)

    quantizer = Quantizer()
    qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
    if quant_type == QuantType.DYNAMIC:
        prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
    else:
        prepared = quantizer.prepare(fused, qconfig_dict)

    prepared(*inputs)
    qgraph = quantizer.convert(prepared)
    qgraph_debug = quantizer.convert(prepared, debug=True)

    result = qgraph(*inputs)
    result_debug = qgraph_debug(*inputs)
    self.assertEqual(
        (result - result_debug).abs().max(), 0,
        'Expecting debug and non-debug option to produce identical result')

    if debug:
        print()
        print('quant type:', quant_type)
        print('original graph module:', type(model))
        self.printGraphModule(original)
        print()
        print('quantized graph module:', type(qgraph))
        self.printGraphModule(qgraph)
        print()
    self.checkGraphModuleNodes(
        qgraph, expected_node, expected_node_occurrence, expected_node_list)
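# A usage sketch for the NodeSpec-based checker above; the single-conv model
# and the expected nodes are illustrative assumptions, not part of the suite.
def test_single_conv_node_specs_example(self):
    class ConvModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    data = (torch.rand(1, 3, 10, 10),)
    self.checkGraphModeFxOp(
        ConvModel(), data, QuantType.STATIC,
        expected_node=ns.call_module(nnq.Conv2d),
        # exactly one quantize at the input and one dequantize at the output
        expected_node_occurrence={
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1,
        },
        expected_node_list=[
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ])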
def test_functional(self):
    """ Test quantizing functional conv and linear
    """
    class Conv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.stride = (1, 1)
            self.padding = (0, 0)
            self.dilation = (1, 1)
            self.groups = 1

        def forward(self, x, weight):
            return F.conv2d(x, weight, None, self.stride,
                            self.padding, self.dilation, self.groups)

    conv_input = torch.rand(1, 3, 224, 224)
    conv_weight = torch.rand(3, 3, 3, 3)

    class Linear(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, weight):
            return F.linear(x, weight)

    linear_input = torch.rand(8, 5)
    linear_weight = torch.rand(10, 5)

    class LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)

    tests = [
        (False, Conv, (conv_input, conv_weight),
         ('call_function', torch.ops.quantized.conv2d)),
        (True, Linear, (linear_input, linear_weight),
         ('call_function', torch.ops.quantized.linear_dynamic)),
        (False, Linear, (linear_input, linear_weight),
         ('call_function', torch.ops.quantized.linear)),
        (True, LinearModule, (linear_module_input,),
         ('call_module', torch.nn.quantized.dynamic.Linear)),
        (False, LinearModule, (linear_module_input,),
         ('call_module', torch.nn.quantized.Linear)),
    ]

    for is_dynamic, M, inputs, quantized_node in tests:
        m = M().eval()
        qconfig = default_qconfig
        graph = symbolic_trace(m)
        script = torch.jit.script(graph)

        # traced and scripted modules should match the eager module exactly
        a = m(*inputs)
        b = graph(*inputs)
        c = script(*inputs)
        assert (a - b).abs().max() == 0
        assert (a - c).abs().max() == 0
        assert torch.allclose(a, b)
        assert torch.allclose(a, c)

        graph = fuse(graph)

        quantizer = Quantizer()
        qconfig_dict = {'': qconfig}
        if is_dynamic:
            prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
        else:
            prepared = quantizer.prepare(graph, qconfig_dict)
        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)
        qgraph.eval()
        qgraph_debug.eval()
        qgraph_script = torch.jit.script(qgraph)

        d = qgraph(*inputs)
        d_debug = qgraph_debug(*inputs)
        e = qgraph_script(*inputs)
        e_debug = qgraph_debug(*inputs)

        # search the converted graph for the expected quantized node
        found = False
        modules = dict(qgraph.root.named_modules())
        for node in qgraph.graph.nodes:
            if node.op == 'call_function':
                found = found or node.op == quantized_node[0] and \
                    node.target == quantized_node[1]
            elif node.op == 'call_module':
                found = found or node.op == quantized_node[0] and \
                    type(modules[node.target]) == quantized_node[1]
        assert found, 'Expected to find quantized node:' + str(quantized_node)
        # assert (a - d).abs().max() < 2
        assert torch.allclose(d, e)
        assert (d - d_debug).abs().max() == 0
        assert (e - e_debug).abs().max() == 0