Exemplo n.º 1
0
    def checkGraphModeFxOp(self, model, inputs, quantized_node, quant_type=QuantType.STATIC, debug=False):
        """Quantize ``model`` with FX graph mode quantization and check that
        the converted graph contains ``quantized_node``.

        The model is traced, fused, prepared, run once on ``inputs``, and then
        converted twice (with and without ``debug=True``); both converted
        graphs must produce identical outputs.

        Args:
            model: floating point torch.nn.Module
            inputs: positional sample input arguments for model; when a list
                is passed, only its first element is used
            quantized_node: a tuple of 2 elements, first element is
                the op type for GraphModule node, second element
                is the target function for call_function and
                type of the module for call_module
            quant_type: QuantType selecting the flow: QAT puts the model in
                train mode (eval mode otherwise), DYNAMIC uses
                ``prepare_dynamic`` (``prepare`` otherwise)
            debug: when True, print the original and quantized graph modules
        """
        # TODO: make img_data a single example instead of a list
        if isinstance(inputs, list):
            inputs = inputs[0]
        if quant_type == QuantType.QAT:
            model.train()
        else:
            model.eval()
        original = symbolic_trace(model)
        fused = fuse(original)

        quantizer = Quantizer()
        # TODO: uncomment after we make per channel observer work in the flow
        # qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
        qconfig_dict = {'' : default_qconfig}
        if quant_type == QuantType.DYNAMIC:
            prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
        else:
            prepared = quantizer.prepare(fused, qconfig_dict)

        # Run the prepared model once on the sample inputs before converting.
        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)

        result = qgraph(*inputs)
        result_debug = qgraph_debug(*inputs)

        # The message must be passed to assertEqual itself; writing it after
        # the call only builds a throwaway tuple and the message is lost.
        self.assertEqual(
            (result - result_debug).abs().max(), 0,
            msg='Expecting debug and non-debug option to produce identical result')

        if debug:
            print()
            print('quant type:', quant_type)
            print('origianl graph module:', type(model))
            self.printGraphModule(original)
            print()
            print('quantized graph module:', type(qgraph))
            self.printGraphModule(qgraph)
            print()
        self.checkGraphModuleHasNode(qgraph, quantized_node)
Exemplo n.º 2
0
    def checkGraphModeFxOp(self, model, inputs, quant_type,
                           expected_node=None,
                           expected_node_occurrence=None,
                           expected_node_list=None,
                           debug=False):
        """ Quantizes model with graph mode quantization on fx and checks the
        nodes of the quantized graph against the given expectations.

        The model is traced, fused, prepared, run once on ``inputs``, and then
        converted twice (with and without ``debug=True``); both converted
        graphs must produce identical outputs.

        Args:
            model: floating point torch.nn.Module
            inputs: positional sample input arguments for model; when a list
                  is passed, only its first element is used
            quant_type: QuantType selecting the flow: QAT puts the model in
                  train mode (eval mode otherwise), DYNAMIC uses
                  ``prepare_dynamic`` (``prepare`` otherwise)
            expected_node: NodeSpec
                  e.g. NodeSpec.call_function(torch.quantize_per_tensor)
            expected_node_occurrence: a dict from NodeSpec to
                  expected number of occurrences (int)
                  e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
                        NodeSpec.call_method('dequantize'): 1}
            expected_node_list: a list of NodeSpec, used to check the order
                  of the occurrence of Node
                  e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                        NodeSpec.call_module(nnq.Conv2d),
                        NodeSpec.call_function(F.hardtanh_),
                        NodeSpec.call_method('dequantize')]
            debug: when True, print the original and quantized graph modules
        """
        # TODO: make img_data a single example instead of a list
        if isinstance(inputs, list):
            inputs = inputs[0]
        if quant_type == QuantType.QAT:
            model.train()
        else:
            model.eval()
        original = symbolic_trace(model)
        fused = fuse(original)

        quantizer = Quantizer()
        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
        if quant_type == QuantType.DYNAMIC:
            prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
        else:
            prepared = quantizer.prepare(fused, qconfig_dict)

        # Run the prepared model once on the sample inputs before converting.
        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)

        result = qgraph(*inputs)
        result_debug = qgraph_debug(*inputs)

        # The message must be passed to assertEqual itself; writing it after
        # the call only builds a throwaway tuple and the message is lost.
        self.assertEqual(
            (result - result_debug).abs().max(), 0,
            msg='Expecting debug and non-debug option to produce identical result')

        if debug:
            print()
            print('quant type:', quant_type)
            print('origianl graph module:', type(model))
            self.printGraphModule(original)
            print()
            print('quantized graph module:', type(qgraph))
            self.printGraphModule(qgraph)
            print()
        self.checkGraphModuleNodes(
            qgraph, expected_node, expected_node_occurrence, expected_node_list)
Exemplo n.º 3
0
    def test_functional(self):
        """ Test quantizing functional conv and linear

        For each (is_dynamic, module class, inputs, expected node) case the
        module is traced, scripted, fused, prepared, calibrated on the sample
        inputs and converted. The test then checks that the converted graph
        contains the expected quantized node and that the eager, scripted and
        debug variants of the converted graph agree on the sample inputs.
        """
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding,
                                self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight):
                return F.linear(x, weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        # Each entry: (is_dynamic, module class, inputs, (node op, target)).
        tests = [
            (False, Conv, (conv_input, conv_weight),
             ('call_function', torch.ops.quantized.conv2d)),
            (True, Linear, (linear_input, linear_weight),
             ('call_function', torch.ops.quantized.linear_dynamic)),
            (False, Linear, (linear_input, linear_weight),
             ('call_function', torch.ops.quantized.linear)),
            (True, LinearModule, (linear_module_input, ),
             ('call_module', torch.nn.quantized.dynamic.Linear)),
            (False, LinearModule, (linear_module_input, ),
             ('call_module', torch.nn.quantized.Linear)),
        ]

        for is_dynamic, M, inputs, quantized_node in tests:
            m = M().eval()

            graph = symbolic_trace(m)
            script = torch.jit.script(graph)

            # Tracing and scripting must not change the float results.
            a = m(*inputs)
            b = graph(*inputs)
            c = script(*inputs)
            assert (a - b).abs().max() == 0
            assert (a - c).abs().max() == 0
            assert torch.allclose(a, b)
            assert torch.allclose(a, c)

            graph = fuse(graph)

            quantizer = Quantizer()
            qconfig_dict = {'': default_qconfig}
            if is_dynamic:
                prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
            else:
                prepared = quantizer.prepare(graph, qconfig_dict)

            # Run the prepared model once on the sample inputs.
            prepared(*inputs)

            qgraph = quantizer.convert(prepared)
            qgraph_debug = quantizer.convert(prepared, debug=True)
            qgraph.eval()
            qgraph_debug.eval()
            qgraph_script = torch.jit.script(qgraph)

            d = qgraph(*inputs)
            d_debug = qgraph_debug(*inputs)
            e = qgraph_script(*inputs)
            # NOTE(review): this re-runs the eager debug graph rather than a
            # scripted debug graph -- confirm whether that is intended.
            e_debug = qgraph_debug(*inputs)

            expected_op, expected_target = quantized_node
            modules = dict(qgraph.root.named_modules())
            found = False
            for node in qgraph.graph.nodes:
                if node.op != expected_op:
                    continue
                if node.op == 'call_function':
                    actual = node.target
                elif node.op == 'call_module':
                    # Exact type comparison on purpose: an isinstance check
                    # would also accept subclasses of the expected module.
                    actual = type(modules[node.target])
                else:
                    continue
                if actual == expected_target:
                    found = True
                    break
            # The message previously referenced an undefined name, which
            # raised NameError instead of reporting the missing node.
            assert found, 'Expected to find quantized node:' + str(
                quantized_node)
            # assert (a-d).abs().max() < 2
            assert torch.allclose(d, e)
            assert (d - d_debug).abs().max() == 0
            assert (e - e_debug).abs().max() == 0