Example #1
    def test_insert_quant_dequant_conv_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3).float()

            def forward(self, x):
                return self.conv(x)

        m = torch.jit.script(M())

        m = prepare_dynamic_script(m, {'': default_dynamic_qconfig})
        data = torch.randn(1, 3, 10, 10, dtype=torch.float)

        m(data)

        m = wrap_cpp_module(
            torch._C._jit_pass_insert_quant_dequant(m._c, "forward", False,
                                                    True))

        assert len(m._modules._c.items()) == 1, \
            'Expected to have a single conv submodule'

        m(data)
        quant_func = "aten::quantize_per_tensor"

        # quantizing activations
        FileCheck().check("aten::_choose_qparams_per_tensor") \
                   .check(quant_func) \
                   .check("prim::CallMethod[name=\"forward\"]") \
                   .check_not(quant_func) \
                   .check("return") \
                   .run(str(get_forward_graph(m._c)))
        # quantizing weight in forward function of conv module, no choose_qparams
        FileCheck().check_not("aten::_choose_qparams_per_tensor") \
                   .check(quant_func) \
                   .check("prim::CallMethod[name=\"_conv_forward\"]") \
                   .check_not(quant_func) \
                   .check("return") \
                   .run(str(get_forward_graph(m.conv._c)))
        # shouldn't have quant/dequant in the _conv_forward function
        FileCheck().check_not(quant_func) \
                   .check("aten::conv2d") \
                   .check_not(quant_func) \
                   .check("return") \
                   .run(str(get_module_method(m, 'conv', '_conv_forward').graph))
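
A minimal standalone sketch of the FileCheck pattern these tests lean on; only torch.testing.FileCheck and torch.jit.script are assumed from the real API, and the add_relu function is a hypothetical example, not part of the test above.

import torch
from torch.testing import FileCheck

# Hypothetical scripted function, used only to produce a graph to check.
@torch.jit.script
def add_relu(x, y):
    return torch.relu(torch.add(x, y))

# Checks run in order against the textual graph dump: each check() must match
# after the previous match, check_not() forbids a pattern in between, and
# check_count() requires an exact number of occurrences.
graph_str = str(add_relu.graph)
FileCheck().check("aten::add") \
           .check("aten::relu") \
           .check("return") \
           .run(graph_str)
FileCheck().check_count("aten::relu", 1, exactly=True) \
           .check_not("aten::sub") \
           .run(graph_str)
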
Example #2
    def test_insert_quant_dequant_linear_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc1 = torch.nn.Linear(5, 5).float()
                self.fc2 = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                x = self.fc1(x)
                return self.fc2(x)

        m = torch.jit.script(M())

        m = prepare_dynamic_script(m, {'': default_dynamic_qconfig})
        data = torch.randn(5, 5, dtype=torch.float)

        m(data)
        m = wrap_cpp_module(
            torch._C._jit_pass_insert_quant_dequant(m._c, "forward", False,
                                                    True))

        assert len(m._modules._c.items()) == 2, \
            'Expected to have two linear submodules'

        m(data)
        quant_func = "aten::quantize_per_tensor"

        # quantizing activations
        FileCheck().check("aten::_choose_qparams_per_tensor") \
                   .check(quant_func) \
                   .check("prim::CallMethod[name=\"forward\"]") \
                   .check("aten::_choose_qparams_per_tensor") \
                   .check(quant_func) \
                   .check("prim::CallMethod[name=\"forward\"]") \
                   .check_not(quant_func) \
                   .check("return") \
                   .run(str(get_forward_graph(m._c)))
        # quantizing weight in forward function of fc module, no choose_qparams
        FileCheck().check_not("aten::_choose_qparams_per_tensor") \
                   .check(quant_func) \
                   .check("prim::CallFunction") \
                   .check_not(quant_func) \
                   .check("return") \
                   .run(str(get_forward_graph(m.fc1._c)))
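
For context, a hedged eager-mode sketch of the public dynamic-quantization API covering the same linear case end to end; only torch.quantization.quantize_dynamic is assumed, and the TwoLinear module is illustrative rather than part of the test above.

import torch

# Hypothetical two-layer model, mirroring the M module scripted above.
class TwoLinear(torch.nn.Module):
    def __init__(self):
        super(TwoLinear, self).__init__()
        self.fc1 = torch.nn.Linear(5, 5).float()
        self.fc2 = torch.nn.Linear(5, 5).float()

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = TwoLinear().eval()
# Dynamic quantization: weights are quantized ahead of time, activation
# qparams are chosen on the fly per batch (the role played by
# aten::_choose_qparams_per_tensor in the graph-mode test above).
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)
out = quantized(torch.randn(5, 5))
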
    def test_optimize_for_mobile(self):
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)

        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand(conv_weight_shape)
        conv_bias = torch.rand(conv_bias_shape)
        result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
        weight_output_dim = 24
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class MyTestModule(torch.nn.Module):
            def __init__(self):
                super(MyTestModule, self).__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape))
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

        class BNTestModule(torch.nn.Module):
            def __init__(self):
                super(BNTestModule, self).__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        data_shape = (batch_size, input_channels, height, width)
        input_data = torch.normal(1, 20, size=data_shape)

        scripted_model = torch.jit.script(MyTestModule())
        scripted_model.eval()
        initial_result = scripted_model(input_data)

        optimized_scripted_model = optimize_for_mobile(scripted_model)
        optimized_result = optimized_scripted_model(input_data)

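        # After full optimization, conv2d + relu and linear + relu should be rewritten
        # into the prepacked clamp ops, and add + relu fused into aten::_add_relu.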
        FileCheck().check_not("Tensor = aten::conv2d") \
                   .check_not("Tensor = prim::CallFunction") \
                   .check_not("prepacked::conv2d_clamp_prepack") \
                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_prepack") \
                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
                   .check_not("aten::add(") \
                   .check_not("aten::relu(") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.graph)
        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)


        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
        optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

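        # With INSERT_FOLD_PREPACK_OPS blocklisted, aten::conv2d stays in the graph
        # and no prepacked ops are introduced.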
        FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
                   .check_not("prepacked::linear_clamp_run") \
                   .check_not("prepacked::conv2d_clamp_run") \
                   .run(optimized_scripted_model_no_prepack.graph)
        torch.testing.assert_allclose(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)


        bn_test_module = BNTestModule()
        bn_scripted_module = torch.jit.script(bn_test_module)
        bn_scripted_module.eval()
        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
                   .run(str(get_forward(bn_scripted_module._c).graph))

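        # Conv-BN folding still runs when only prepack insertion is blocklisted,
        # shrinking the exported op set from 14 ops to 1.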
        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

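        # Blocklisting CONV_BN_FUSION keeps the aten::batch_norm node in the optimized graph.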
        optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
        FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
                   .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

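        # optimize_for_mobile tags the returned module with a `mobile_optimized` attribute.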
        class MyMobileOptimizedTagTest(torch.nn.Module):
            def __init__(self):
                super(MyMobileOptimizedTagTest, self).__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

        mobile_optimized_tag_module = MyMobileOptimizedTagTest()
        m = torch.jit.script(mobile_optimized_tag_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        tag = getattr(opt_m, "mobile_optimized", None)
        self.assertTrue(tag)

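        # Exported methods other than forward are dropped unless listed in preserved_methods.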
        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self):
                super(MyPreserveMethodsTest, self).__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        preserve_method_module = MyPreserveMethodsTest()
        m = torch.jit.script(preserve_method_module)
        m.eval()
        opt_m = optimize_for_mobile(m)
        no_preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertEqual(no_preserveThis, None)
        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertNotEqual(preserveThis, None)
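
A hedged standalone sketch of the optimize_for_mobile entry points exercised above, assuming MobileOptimizerType is importable from torch.utils.mobile_optimizer (in some versions it is defined in torch._C); the LinearRelu module and the extra_method name are illustrative only.

import torch
import torch.nn.functional as F
from torch.utils.mobile_optimizer import MobileOptimizerType, optimize_for_mobile

# Hypothetical module: a linear layer followed by relu, plus an exported
# method that would be dropped by default during optimization.
class LinearRelu(torch.nn.Module):
    def __init__(self):
        super(LinearRelu, self).__init__()
        self.linear = torch.nn.Linear(24, 24)

    def forward(self, x):
        return F.relu(self.linear(x))

    @torch.jit.export
    def extra_method(self):
        pass

scripted = torch.jit.script(LinearRelu())
scripted.eval()

# Full optimization: expected to rewrite supported patterns (e.g. linear + relu
# into the prepacked clamp op) and to tag the result as `mobile_optimized`.
opt = optimize_for_mobile(scripted)

# Skip the prepack rewrite and keep extra_method on the optimized module.
opt_no_prepack = optimize_for_mobile(
    scripted,
    {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS},
    preserved_methods=["extra_method"])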