Beispiel #1
0
    def test_qconfig_none(self):
        """Check that a ``None`` entry in ``qconfig_dict`` opts a submodule out.

        The global entry (key ``''``) assigns ``default_qconfig`` to the model,
        while ``'conv2': None`` overrides it for ``conv2``. After prepare, a
        forward pass, and convert, the expected graph is: quantize the input,
        run ``conv1`` as a quantized conv, dequantize, then run ``conv2`` as a
        float ``nn.Conv2d``.
        """
        class M(torch.nn.Module):
            def __init__(self):
                # Zero-argument super(), consistent with the other model
                # definitions in this file.
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        m = symbolic_trace(m)
        # '' is the global qconfig; the None value for 'conv2' is what this
        # test expects to leave conv2 unquantized.
        qconfig_dict = {'': default_qconfig, 'conv2': None}
        m = prepare_static_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        # Run data through the prepared model before converting
        # (presumably to populate observers — confirm against prepare_static_fx).
        m(data)
        m = convert_static_fx(m)
        # The converted model must still run end to end.
        m(data)
        # first conv is quantized, second conv is not quantized
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
            ns.call_module(nn.Conv2d),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=node_list)
Beispiel #2
0
    def test_dict_output(self):
        """Check that FX static quantization copes with dict inputs and outputs.

        The toy model reads its tensor from ``x["input"]`` and returns its
        result under the key ``"output"``. The test passes when tracing,
        prepare, a forward pass, convert, and a second forward pass all
        complete without raising.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                conv_out = self.conv(x["input"])
                return {"output": conv_out}

        # The sample input is created before the model so that random-number
        # consumption happens in the same order as before.
        sample = {"input": torch.randn(1, 1, 1, 1)}
        traced = symbolic_trace(M())
        traced = traced.eval()
        prepared = prepare_static_fx(traced, {"": default_qconfig})
        prepared(sample)
        quantized = convert_static_fx(prepared)
        quantized(sample)
Beispiel #3
0
    def test_remove_qconfig(self):
        """Check that convert leaves no ``qconfig`` attribute on any submodule.

        A single-``AvgPool2d`` model is traced, prepared with the global
        ``default_qconfig``, run once, converted, and run again. Then every
        module returned by ``named_modules()`` is checked to make sure it
        lacks a ``qconfig`` attribute.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.avg_pool = torch.nn.AvgPool2d(1)

            def forward(self, x):
                return self.avg_pool(x)

        traced = symbolic_trace(M().eval())
        prepared = prepare_static_fx(traced, {'': default_qconfig})
        sample = torch.randn(1, 1, 1, 1)
        prepared(sample)
        converted = convert_static_fx(prepared)
        converted(sample)
        for mod_name, submodule in converted.named_modules():
            leftover = hasattr(submodule, 'qconfig')
            self.assertFalse(leftover, 'qconfig is not removed for ' + mod_name)