Code Example #1
    def checkGraphModeFxOp(self, model, inputs, quantized_node, quant_type=QuantType.STATIC, debug=False):
        """ Quantizes model with graph mode quantization on fx and check if the
        quantized model contains the quantized_node

        Args:
            model: floating point torch.nn.Module
            inputs: one positional sample input arguments for model
            quantized_node: a tuple of 2 elements, first element is
                the op type for GraphModule node, second element
                is the target function for call_function and
                type of the module for call_module
        """
        # TODO: make img_data a single example instead of a list
        if isinstance(inputs, list):
            inputs = inputs[0]
        if quant_type == QuantType.QAT:
            model.train()
        else:
            model.eval()
        original = symbolic_trace(model)
        fused = fuse(original)

        quantizer = Quantizer()
        # TODO: uncomment after we make per channel observer work in the flow
        # qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
        qconfig_dict = {'': default_qconfig}
        if quant_type == QuantType.DYNAMIC:
            prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
        else:
            prepared = quantizer.prepare(fused, qconfig_dict)

        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)

        result = qgraph(*inputs)
        result_debug = qgraph_debug(*inputs)

        self.assertEqual(
            (result - result_debug).abs().max(), 0,
            'Expecting debug and non-debug option to produce identical result')

        if debug:
            print()
            print('quant type:', quant_type)
            print('original graph module:', type(model))
            self.printGraphModule(original)
            print()
            print('quantized graph module:', type(qgraph))
            self.printGraphModule(qgraph)
            print()
        self.checkGraphModuleHasNode(qgraph, quantized_node)
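
A minimal usage sketch for this helper. The SingleConv model and the node spec below are illustrative assumptions, not taken from the original test file:

    class SingleConv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    # inputs is unpacked as prepared(*inputs), so pass a tuple of tensors
    data = (torch.rand(1, 3, 10, 10),)
    # after static quantization the float conv should be swapped for nnq.Conv2d
    self.checkGraphModeFxOp(SingleConv(), data,
                            ('call_module', nnq.Conv2d),
                            quant_type=QuantType.STATIC)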
Code Example #2
File: test_quantize_fx.py  Project: jegrace/pytorch
    def _test_model_impl(self,
                         mode,
                         name,
                         model,
                         eager_quantizable_model,
                         check_with_eager=True,
                         diff_of_quant=None,
                         diff_from_eager=None):
        if diff_of_quant is None or diff_from_eager is None:
            diff_of_quant = {}
            diff_from_eager = {}

        if mode not in diff_of_quant or mode not in diff_from_eager:
            diff_of_quant[mode] = {}
            diff_from_eager[mode] = {}

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        output_value = torch.randint(0, 1, (1, ))

        # print('quantizing:', name, ' mode:', mode)
        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        graph_module = symbolic_trace(model)
        # print('graph module:', graph_module.src)
        script = torch.jit.script(graph_module)

        # make sure graph module and script module are both runnable
        original_out = graph_module(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)
        self.assertEqual(
            (original_out - script_out).abs().max(), 0,
            'Results of original graph module and script module do not match')

        # set to train just before quantization
        if mode != 'static':
            model.train()

        graph_module = fuse(graph_module)
        quantizer = Quantizer()
        prepared = quantizer.prepare(graph_module, qconfig_dict)

        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, prepared),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            assert prepared.training, 'prepared must be in training mode for qat'
            optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(prepared, criterion,
                            optimizer, [(input_value, output_value)],
                            torch.device('cpu'), 1)
        else:
            for i in range(10):
                prepared(input_value)

        # print('after observation root:', prepared.root)

        qgraph = quantizer.convert(prepared)
        # print('after quantization root:', qgraph.root)
        # print('after quantization code:', qgraph.src)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)
        # print('quantized and scripted:', qgraph_script.graph)

        qgraph_out = qgraph(input_value)
        qgraph_script_out = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out, qgraph_script_out), \
                'graph module and scripted graph module outputs should match'
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            ref_out = qeager(input_value)
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            if mode == 'ddp':
                mp.spawn(run_ddp,
                         args=(world_size, qeager),
                         nprocs=world_size,
                         join=True)
            elif mode == 'qat':
                assert qeager.training, 'qeager should be in training mode for qat'
                optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
                train_one_epoch(qeager, criterion, optimizer,
                                [(input_value, output_value)],
                                torch.device('cpu'), 1)
            else:
                for i in range(10):
                    qeager(input_value)

            # print('ref after observation:', qeager)

            convert(qeager, inplace=True)
            qeager.eval()

            # print('ref after quantization:', qeager)
            qeager_out = qeager(input_value)
            qeager_script = torch.jit.script(qeager)
            qscript_out = qeager_script(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out -
                                               qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(
                        diff_from_eager[mode][name], 0,
                        'Result of graph mode quantization and ' +
                        'eager mode quantization on model: ' + name +
                        ' should match. Mode: ' + mode + ' diff:' +
                        str(diff_from_eager[mode][name]))
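
A hypothetical invocation of _test_model_impl, assuming torchvision is installed (the eager-quantizable counterparts, which define fuse_model(), live under torchvision.models.quantization):

    from torchvision import models

    model = models.resnet18(pretrained=False).eval()
    eager = models.quantization.resnet18(pretrained=False, quantize=False).eval()
    self._test_model_impl('static', 'resnet18', model, eager)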
Code Example #3
File: test_quantize_fx.py  Project: jegrace/pytorch
    def test_general_value_ops(self):
        """ A test that checks correct patterns are produced for
        all supported general value ops like aten::avg_pool2d \
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d(
                    (1, 1, 1))
                self.leaky_relu = torch.nn.LeakyReLU()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                x = self.leaky_relu(x)
                x = F.leaky_relu(x)
                x = F.leaky_relu(x, inplace=True)
                x = x.leaky_relu()
                x.leaky_relu_()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
                x = x.hardsigmoid()
                x.hardsigmoid_()
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                # F.sigmoid is deprecated
                x = x.sigmoid()
                x.sigmoid_()
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                x.tanh_()
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        quantizer = Quantizer()
        qconfig_dict = {'': default_qconfig}
        prepared = quantizer.prepare(original, qconfig_dict)
        # not runnable
        quantized = quantizer.convert(prepared)

        # This checks that the dequantize after the first conv is propagated
        # to the end of the graph, so that no extra observers are inserted.
        # Check the exact counts of quantize and dequantize nodes.
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
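
The "general value op" property under test can be observed directly on quantized tensors: these ops consume and produce quantized tensors and reuse the input's quantization parameters, which is why no extra observers are needed. A standalone sketch, not from the test file:

    xq = torch.quantize_per_tensor(torch.rand(1, 3, 8, 8),
                                   scale=0.1, zero_point=0, dtype=torch.quint8)
    yq = F.adaptive_avg_pool2d(xq, (1, 1))
    assert yq.is_quantized
    assert yq.q_scale() == xq.q_scale()  # output reuses the input qparams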
Code Example #4
File: test_quantize_fx.py  Project: jegrace/pytorch
    def test_general_shape_ops(self):
        """ A test that checks dequantize will be swapped for
        all supported general shape ops like aten::flatten
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
                self.dropout = torch.nn.Dropout()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                # add_scalar
                x = x + 3
                # mul_scalar
                x = x * 3
                # add_scalar_out
                x += 3
                # mul_scalar_out
                x *= 3
                # add_scalar_relu
                x = x + 3
                x = F.relu(x)
                # add_scalar_relu_out
                x += 3
                x = F.relu(x)
                # mul_scalar_relu
                x = x * 3
                x = F.relu(x)
                # mul_scalar_relu_out
                x *= 3
                x = F.relu(x)
                x = self.maxpool1d(x)
                x = self.maxpool2d(x)
                x = self.maxpool3d(x)
                x = torch.flatten(x)
                x = torch.max(x)
                x = torch.min(x)
                x = x.reshape([-1])
                x = x.resize_(1, 1, x.numel())
                x = x.view(-1)
                # prim::ListConstruct
                xs = [x, x]
                # prim::ListUnpack
                x, y = xs
                # prim::TupleConstruct
                xs = (x, x)
                # prim::TupleUnpack
                x, y = xs
                x = x.transpose(1, 2)
                x = x.contiguous()
                x, y = torch.chunk(x, 2)
                x = F.dropout(x)
                x = self.dropout(x)
                x, _ = torch.sort(x)
                x = x.permute(0, 2, 3, 1)
                x = x.repeat_interleave(3, 1)
                x = torch.repeat_interleave(x, 3, 1)
                x = self.relu(x)
                x = F.relu(x)
                x = F.relu(x, inplace=True)
                x = x.relu()
                x.relu_()
                x = x.squeeze(0)
                x.squeeze_(0)
                x = torch.squeeze(x, 0)
                x = x.unsqueeze(0)
                x.unsqueeze_(0)
                x = torch.unsqueeze(x, 0)
                x = x.detach()
                x.detach_()
                x = x.repeat(4, 2)
                y = []
                y.append(x)
                z = torch.stack(y, 0)
                z = [z, z]
                x, _ = z
                x = self.conv2(x)
                return x

        data = torch.rand(1, 3, 10, 10)
        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        quantizer = Quantizer()
        qconfig_dict = {'': default_qconfig}
        prepared = quantizer.prepare(original, qconfig_dict)
        # not runnable
        quantized = quantizer.convert(prepared)

        # This checks that the dequantize after the first conv is propagated
        # to the end of the graph, so that no extra observers are inserted
        # and both conv patterns are quantized to quantized::conv2d.
        # There is one quantize_per_tensor for the input; check the exact
        # counts of quantize and dequantize nodes.
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
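
General shape ops likewise operate on the quantized tensor directly, changing only its shape and leaving the quantization parameters untouched, which is what allows the single dequantize to sink to the end of the graph. A standalone sketch, not from the test file:

    xq = torch.quantize_per_tensor(torch.rand(1, 3, 4, 4),
                                   scale=0.1, zero_point=0, dtype=torch.quint8)
    yq = torch.flatten(xq)  # shape changes, quantization parameters do not
    assert yq.is_quantized
    assert yq.q_scale() == xq.q_scale()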
Code Example #5
    def checkGraphModeFxOp(self, model, inputs, quant_type,
                           expected_node=None,
                           expected_node_occurrence=None,
                           expected_node_list=None,
                           debug=False):
        """ Quantizes model with graph mode quantization on fx and check if the
        quantized model contains the quantized_node

        Args:
            model: floating point torch.nn.Module
            inputs: one positional sample input arguments for model
            expected_node: NodeSpec
                  e.g. NodeSpec.call_function(torch.quantize_per_tensor)
            expected_node_occurrence: a dict from NodeSpec to
                  expected number of occurences (int)
                  e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
                        NodeSpec.call_method('dequantize'): 1}
            expected_node_list: a list of NodeSpec, used to check the order
                  of the occurrence of Node
                  e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                        NodeSpec.call_module(nnq.Conv2d),
                        NodeSpec.call_function(F.hardtanh_),
                        NodeSpec.call_method('dequantize')]
        """
        # TODO: make img_data a single example instead of a list
        if isinstance(inputs, list):
            inputs = inputs[0]
        if quant_type == QuantType.QAT:
            model.train()
        else:
            model.eval()
        original = symbolic_trace(model)
        fused = fuse(original)

        quantizer = Quantizer()
        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
        if quant_type == QuantType.DYNAMIC:
            prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
        else:
            prepared = quantizer.prepare(fused, qconfig_dict)

        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)

        result = qgraph(*inputs)
        result_debug = qgraph_debug(*inputs)

        self.assertEqual(
            (result - result_debug).abs().max(), 0,
            'Expecting debug and non-debug option to produce identical result')

        if debug:
            print()
            print('quant type:', quant_type)
            print('original graph module:', type(model))
            self.printGraphModule(original)
            print()
            print('quantized graph module:', type(qgraph))
            self.printGraphModule(qgraph)
            print()
        self.checkGraphModuleNodes(
            qgraph, expected_node, expected_node_occurrence, expected_node_list)
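
A minimal usage sketch for this variant of the helper, reusing the hypothetical SingleConv model from the sketch after Code Example #1 (ns is the NodeSpec alias used elsewhere in these tests):

    data = (torch.rand(1, 3, 10, 10),)
    self.checkGraphModeFxOp(
        SingleConv(), data, QuantType.STATIC,
        expected_node=ns.call_module(nnq.Conv2d),
        expected_node_occurrence={
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1,
        })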
Code Example #6
File: test_quantize_fx.py  Project: bradleyhd/pytorch
    def test_functional(self):
        """ Test quantizing functional conv and linear
        """
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding,
                                self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight):
                return F.linear(x, weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
            (False, Conv, (conv_input, conv_weight),
             ('call_function', torch.ops.quantized.conv2d)),
            (True, Linear, (linear_input, linear_weight),
             ('call_function', torch.ops.quantized.linear_dynamic)),
            (False, Linear, (linear_input, linear_weight),
             ('call_function', torch.ops.quantized.linear)),
            (True, LinearModule, (linear_module_input, ),
             ('call_module', torch.nn.quantized.dynamic.Linear)),
            (False, LinearModule, (linear_module_input, ),
             ('call_module', torch.nn.quantized.Linear)),
        ]

        for is_dynamic, M, inputs, quantized_node in tests:
            m = M().eval()
            qconfig = default_qconfig

            graph = symbolic_trace(m)
            script = torch.jit.script(graph)

            a = m(*inputs)
            b = graph(*inputs)
            c = script(*inputs)
            assert (a - b).abs().max() == 0
            assert (a - c).abs().max() == 0
            assert torch.allclose(a, b)
            assert torch.allclose(a, c)

            graph = fuse(graph)

            quantizer = Quantizer()
            qconfig_dict = {'': qconfig}
            if is_dynamic:
                prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
            else:
                prepared = quantizer.prepare(graph, qconfig_dict)

            prepared(*inputs)

            qgraph = quantizer.convert(prepared)
            qgraph_debug = quantizer.convert(prepared, debug=True)
            qgraph.eval()
            qgraph_debug.eval()
            qgraph_script = torch.jit.script(qgraph)

            d = qgraph(*inputs)
            d_debug = qgraph_debug(*inputs)
            e = qgraph_script(*inputs)
            e_debug = qgraph_debug(*inputs)

            found = False
            modules = dict(qgraph.root.named_modules())
            for node in qgraph.graph.nodes:
                if node.op == 'call_function':
                    found = found or (node.op == quantized_node[0] and
                                      node.target == quantized_node[1])
                elif node.op == 'call_module':
                    found = found or (node.op == quantized_node[0] and
                                      type(modules[node.target]) == quantized_node[1])
            assert found, 'Expected to find quantized node: ' + str(quantized_node)
            # assert (a-d).abs().max() < 2
            assert torch.allclose(d, e)
            assert (d - d_debug).abs().max() == 0
            assert (e - e_debug).abs().max() == 0
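
The node search loop above re-implements the matching done by checkGraphModuleHasNode in Code Example #1. It could be factored into a reusable helper roughly as follows (the helper name is illustrative):

    def graph_has_node(qgraph, quantized_node):
        op, target = quantized_node
        modules = dict(qgraph.root.named_modules())
        for node in qgraph.graph.nodes:
            if node.op != op:
                continue
            if op == 'call_function' and node.target == target:
                return True
            if op == 'call_module' and type(modules[node.target]) == target:
                return True
        return False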