예제 #1
0
    def checkGraphModeFxOp(self, model, inputs, quantized_node, quant_type=QuantType.STATIC, debug=False):
        """ Quantizes model with graph mode quantization on fx and check if the
        quantized model contains the quantized_node

        Args:
            model: floating point torch.nn.Module
            inputs: one positional sample input arguments for model; when a
                list is passed only its first element is used
            quantized_node: a tuple of 2 elements, first element is
                the op type for GraphModule node, second element
                is the target function for call_function and
                type of the module for call_module
            quant_type: QuantType selecting the flow. QuantType.QAT puts the
                model in training mode (all other values use eval mode) and
                QuantType.DYNAMIC prepares with ``prepare_dynamic`` instead
                of ``prepare``
            debug: when True, print the original and the quantized graph
                modules before checking for the node
        """
        # TODO: make img_data a single example instead of a list
        if isinstance(inputs, list):
            inputs = inputs[0]
        if quant_type == QuantType.QAT:
            model.train()
        else:
            model.eval()
        original = symbolic_trace(model)
        fused = fuse(original)

        quantizer = Quantizer()
        # TODO: uncomment after we make per channel observer work in the flow
        # qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
        qconfig_dict = {'' : default_qconfig}
        if quant_type == QuantType.DYNAMIC:
            prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
        else:
            prepared = quantizer.prepare(fused, qconfig_dict)

        # run the sample input once through the prepared module before converting
        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)

        result = qgraph(*inputs)
        result_debug = qgraph_debug(*inputs)

        # The message must be an argument of assertEqual; placed after the
        # call it only builds a discarded tuple and is never reported.
        self.assertEqual(
            (result - result_debug).abs().max(), 0,
            'Expecting debug and non-debug option to produce identical result')

        if debug:
            print()
            print('quant type:', quant_type)
            print('origianl graph module:', type(model))
            self.printGraphModule(original)
            print()
            print('quantized graph module:', type(qgraph))
            self.printGraphModule(qgraph)
            print()
        self.checkGraphModuleHasNode(qgraph, quantized_node)
예제 #2
0
    def _test_model_impl(self,
                         mode,
                         name,
                         model,
                         eager_quantizable_model,
                         check_with_eager=True,
                         diff_of_quant=None,
                         diff_from_eager=None):
        """Quantize ``model`` in fx graph mode and compare it with eager mode.

        The model is symbolically traced, scripted, fused, prepared,
        calibrated/trained according to ``mode``, converted and scripted
        again. When ``eager_quantizable_model`` is given, the same steps are
        run with the eager mode APIs and both quantized outputs are compared.

        Args:
            mode: 'static' uses ``default_qconfig`` and calibrates by running
                the input 10 times; 'qat' trains for one epoch with
                ``train_one_epoch``; 'ddp' spawns ``run_ddp`` workers. Every
                mode other than 'static' uses ``default_qat_qconfig`` and
                switches the model to training mode.
            name: model name; 'inception_v3' gets a 299x299 input, every
                other name a 224x224 input. Also used as the key under which
                the differences are stored.
            model: floating point module to trace and quantize.
            eager_quantizable_model: eager mode counterpart exposing
                ``fuse_model()``, or None to skip the eager comparison.
            check_with_eager: when True, assert that graph mode and eager
                mode outputs are identical.
            diff_of_quant: optional dict; ``diff_of_quant[mode][name]``
                receives the max abs difference between the float output and
                the quantized output.
            diff_from_eager: optional dict; ``diff_from_eager[mode][name]``
                receives the max abs difference between the eager and the
                graph mode quantized outputs.

        Note:
            ``world_size`` (used for 'ddp') is not defined in this method and
            is expected to come from the enclosing module.
        """
        # Initialize each result dict independently so that a dict supplied
        # by the caller is never discarded just because the other is missing.
        if diff_of_quant is None:
            diff_of_quant = {}
        if diff_from_eager is None:
            diff_from_eager = {}
        # setdefault keeps results already recorded for this mode
        diff_of_quant.setdefault(mode, {})
        diff_from_eager.setdefault(mode, {})

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        output_value = torch.randint(0, 1, (1, ))

        # print('quantizing:', name, ' mode:', mode)
        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        graph_module = symbolic_trace(model)
        # print('graph module:', graph_module.src)
        script = torch.jit.script(graph_module)

        # make sure graph module and script module are both runnable
        original_out = graph_module(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)
        self.assertEqual(
            (original_out - script_out).abs().max(), 0,
            'Reslut of original graph module and script module does not match')

        # set to train just before quantization
        if mode != 'static':
            model.train()

        graph_module = fuse(graph_module)
        quantizer = Quantizer()
        prepared = quantizer.prepare(graph_module, qconfig_dict)

        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, prepared),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            assert prepared.training, 'prepared must be in training mode for qat'
            optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
            # criterion is reused below for the eager mode qat run
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(prepared, criterion,
                            optimizer, [(input_value, output_value)],
                            torch.device('cpu'), 1)
        else:
            # calibration
            for _ in range(10):
                prepared(input_value)

        # print('after observation root:', prepared.root)

        qgraph = quantizer.convert(prepared)
        # print('after quantization root:', qgraph.root)
        # print('after quantization code:', qgraph.src)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)
        # print('quantized and scripted:', qgraph_script.graph)

        qgraph_out = qgraph(input_value)
        # keep the scripted module and its output under separate names
        qgraph_script_out = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out,
                                  qgraph_script_out), 'graph, scripted graph'
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            # run the float eager model once; the output itself is not used
            qeager(input_value)
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            if mode == 'ddp':
                mp.spawn(run_ddp,
                         args=(world_size, qeager),
                         nprocs=world_size,
                         join=True)
            elif mode == 'qat':
                assert qeager.training, 'qeager should be in training mode for qat'
                optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
                train_one_epoch(qeager, criterion, optimizer,
                                [(input_value, output_value)],
                                torch.device('cpu'), 1)
            else:
                for _ in range(10):
                    qeager(input_value)

            # print('ref after observation:', qeager)

            convert(qeager, inplace=True)
            qeager.eval()

            # print('ref after quantization:', qeager)
            qeager_out = qeager(input_value)
            # make sure the quantized eager model can be scripted and run
            qeager_script = torch.jit.script(qeager)
            qeager_script(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out -
                                               qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(
                        diff_from_eager[mode][name], 0,
                        'Result of graph mode quantization and ' +
                        'eager mode quantization on model: ' + name +
                        ' should match. Mode: ' + mode + ' diff:' +
                        str(diff_from_eager[mode][name]))
예제 #3
0
    def checkGraphModeFxOp(self, model, inputs, quant_type,
                           expected_node=None,
                           expected_node_occurrence=None,
                           expected_node_list=None,
                           debug=False):
        """ Quantizes model with graph mode quantization on fx and checks the
        nodes of the quantized model against the given expectations

        Args:
            model: floating point torch.nn.Module
            inputs: one positional sample input arguments for model; when a
                  list is passed only its first element is used
            quant_type: QuantType selecting the flow. QuantType.QAT puts the
                  model in training mode (all other values use eval mode) and
                  QuantType.DYNAMIC prepares with ``prepare_dynamic`` instead
                  of ``prepare``
            expected_node: NodeSpec
                  e.g. NodeSpec.call_function(torch.quantize_per_tensor)
            expected_node_occurrence: a dict from NodeSpec to
                  expected number of occurrences (int)
                  e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
                        NodeSpec.call_method('dequantize'): 1}
            expected_node_list: a list of NodeSpec, used to check the order
                  of the occurrence of Node
                  e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                        NodeSpec.call_module(nnq.Conv2d),
                        NodeSpec.call_function(F.hardtanh_),
                        NodeSpec.call_method('dequantize')]
            debug: when True, print the original and the quantized graph
                  modules before checking the nodes
        """
        # TODO: make img_data a single example instead of a list
        if isinstance(inputs, list):
            inputs = inputs[0]
        if quant_type == QuantType.QAT:
            model.train()
        else:
            model.eval()
        original = symbolic_trace(model)
        fused = fuse(original)

        quantizer = Quantizer()
        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
        if quant_type == QuantType.DYNAMIC:
            prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
        else:
            prepared = quantizer.prepare(fused, qconfig_dict)

        # run the sample input once through the prepared module before converting
        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)

        result = qgraph(*inputs)
        result_debug = qgraph_debug(*inputs)

        # The message must be an argument of assertEqual; placed after the
        # call it only builds a discarded tuple and is never reported.
        self.assertEqual(
            (result - result_debug).abs().max(), 0,
            'Expecting debug and non-debug option to produce identical result')

        if debug:
            print()
            print('quant type:', quant_type)
            print('origianl graph module:', type(model))
            self.printGraphModule(original)
            print()
            print('quantized graph module:', type(qgraph))
            self.printGraphModule(qgraph)
            print()
        self.checkGraphModuleNodes(
            qgraph, expected_node, expected_node_occurrence, expected_node_list)
예제 #4
0
    def test_functional(self):
        """ Test quantizing functional conv and linear

        For each case the float module, its traced graph and the scripted
        graph must agree exactly; after quantization the expected quantized
        node must be present and the quantized, scripted and debug outputs
        must agree.
        """
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding,
                                self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight):
                return F.linear(x, weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        # (is_dynamic, module class, inputs, (node op, expected target))
        tests = [
            (False, Conv, (conv_input, conv_weight),
             ('call_function', torch.ops.quantized.conv2d)),
            (True, Linear, (linear_input, linear_weight),
             ('call_function', torch.ops.quantized.linear_dynamic)),
            (False, Linear, (linear_input, linear_weight),
             ('call_function', torch.ops.quantized.linear)),
            (True, LinearModule, (linear_module_input, ),
             ('call_module', torch.nn.quantized.dynamic.Linear)),
            (False, LinearModule, (linear_module_input, ),
             ('call_module', torch.nn.quantized.Linear)),
        ]

        for is_dynamic, M, inputs, quantized_node in tests:
            m = M().eval()
            qconfig = default_qconfig

            graph = symbolic_trace(m)
            script = torch.jit.script(graph)

            # float module, traced graph and scripted graph must match exactly
            a = m(*inputs)
            b = graph(*inputs)
            c = script(*inputs)
            assert (a - b).abs().max() == 0
            assert (a - c).abs().max() == 0
            assert torch.allclose(a, b)
            assert torch.allclose(a, c)

            graph = fuse(graph)

            quantizer = Quantizer()
            qconfig_dict = {'': qconfig}
            if is_dynamic:
                prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
            else:
                prepared = quantizer.prepare(graph, qconfig_dict)

            # run the sample input once through the prepared module
            prepared(*inputs)

            qgraph = quantizer.convert(prepared)
            qgraph_debug = quantizer.convert(prepared, debug=True)
            qgraph.eval()
            qgraph_debug.eval()
            qgraph_script = torch.jit.script(qgraph)

            d = qgraph(*inputs)
            d_debug = qgraph_debug(*inputs)
            e = qgraph_script(*inputs)
            # NOTE(review): this runs the unscripted debug graph again, so
            # e_debug is not the output of a scripted module - confirm intent.
            e_debug = qgraph_debug(*inputs)

            # look for a node with the expected op whose target (function for
            # call_function, module type for call_module) matches
            expected_op, expected_target = quantized_node
            modules = dict(qgraph.root.named_modules())
            found = False
            for node in qgraph.graph.nodes:
                if node.op != expected_op:
                    continue
                if node.op == 'call_function':
                    found = node.target == expected_target
                elif node.op == 'call_module':
                    found = type(modules[node.target]) is expected_target
                if found:
                    break
            # the message previously referenced an undefined name, which
            # turned a failed check into a NameError
            assert found, 'Expected to find quantized node:' + str(
                quantized_node)
            # assert (a-d).abs().max() < 2
            assert torch.allclose(d, e)
            assert (d - d_debug).abs().max() == 0
            assert (e - e_debug).abs().max() == 0