def test_insert_quant_dequant_conv_dynamic(self):
    """Check quant op placement after insert_quant_dequant on a conv module.

    A module wrapping a single ``Conv2d`` is scripted, prepared with
    ``default_dynamic_qconfig``, run once, and passed through the
    ``_jit_pass_insert_quant_dequant`` pass. Three graphs are then inspected
    with ``FileCheck``: the top-level ``forward``, the conv submodule's
    ``forward`` and the conv submodule's ``_conv_forward``.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(3, 5, 3).float()

        def forward(self, x):
            return self.conv(x)

    scripted = torch.jit.script(M())
    prepared = prepare_dynamic_script(scripted, {'': default_dynamic_qconfig})
    sample = torch.randn(1, 3, 10, 10, dtype=torch.float)
    # One forward call on the prepared module before running the pass
    # (presumably so the inserted observers see data -- TODO confirm).
    prepared(sample)

    # NOTE(review): the trailing positional flags (False, True) look like
    # (inplace, is_dynamic) -- confirm against the binding's signature.
    quantized_c = torch._C._jit_pass_insert_quant_dequant(
        prepared._c, "forward", False, True)
    quantized = wrap_cpp_module(quantized_c)

    submodule_count = len(quantized._modules._c.items())
    assert submodule_count == 1, \
        'Expected to have single submodule of conv'
    # The transformed module must still be callable.
    quantized(sample)

    quant_func = "aten::quantize_per_tensor"
    choose_qparams = "aten::_choose_qparams_per_tensor"

    # Top-level forward: the activation gets choose_qparams + quantize
    # ahead of the call into the conv submodule, and no quantize after it.
    top_level_checker = (
        FileCheck()
        .check(choose_qparams)
        .check(quant_func)
        .check("prim::CallMethod[name=\"forward\"]")
        .check_not(quant_func)
        .check("return")
    )
    top_level_checker.run(str(get_forward_graph(quantized._c)))

    # Conv submodule forward: the weight is quantized here, with no
    # choose_qparams call preceding it.
    conv_forward_checker = (
        FileCheck()
        .check_not(choose_qparams)
        .check(quant_func)
        .check("prim::CallMethod[name=\"_conv_forward\"]")
        .check_not(quant_func)
        .check("return")
    )
    conv_forward_checker.run(str(get_forward_graph(quantized.conv._c)))

    # _conv_forward itself must not contain any quantize call around conv2d.
    inner_conv_checker = (
        FileCheck()
        .check_not(quant_func)
        .check("aten::conv2d")
        .check_not(quant_func)
        .check("return")
    )
    inner_conv_graph = get_module_method(quantized, 'conv', '_conv_forward').graph
    inner_conv_checker.run(str(inner_conv_graph))
def test_insert_quant_dequant_linear_dynamic(self):
    """Check quant op placement after insert_quant_dequant on two linears.

    A module chaining two ``Linear`` layers is scripted, prepared with
    ``default_dynamic_qconfig``, run once, and passed through the
    ``_jit_pass_insert_quant_dequant`` pass. ``FileCheck`` is then applied
    to the top-level ``forward`` graph and to the ``forward`` graph of the
    first linear submodule.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.fc1 = torch.nn.Linear(5, 5).float()
            self.fc2 = torch.nn.Linear(5, 5).float()

        def forward(self, x):
            x = self.fc1(x)
            return self.fc2(x)

    scripted = torch.jit.script(M())
    prepared = prepare_dynamic_script(scripted, {'': default_dynamic_qconfig})
    sample = torch.randn(5, 5, dtype=torch.float)
    # One forward call on the prepared module before running the pass
    # (presumably so the inserted observers see data -- TODO confirm).
    prepared(sample)

    # NOTE(review): the trailing positional flags (False, True) look like
    # (inplace, is_dynamic) -- confirm against the binding's signature.
    quantized_c = torch._C._jit_pass_insert_quant_dequant(
        prepared._c, "forward", False, True)
    quantized = wrap_cpp_module(quantized_c)

    submodule_count = len(quantized._modules._c.items())
    assert submodule_count == 2, \
        'Expected to have two submodule of linear'
    # The transformed module must still be callable.
    quantized(sample)

    quant_func = "aten::quantize_per_tensor"
    choose_qparams = "aten::_choose_qparams_per_tensor"

    # Top-level forward: each of the two submodule calls is preceded by its
    # own choose_qparams + quantize pair; nothing is quantized afterwards.
    top_level_checker = (
        FileCheck()
        .check(choose_qparams)
        .check(quant_func)
        .check("prim::CallMethod[name=\"forward\"]")
        .check(choose_qparams)
        .check(quant_func)
        .check("prim::CallMethod[name=\"forward\"]")
        .check_not(quant_func)
        .check("return")
    )
    top_level_checker.run(str(get_forward_graph(quantized._c)))

    # fc1 forward: the weight is quantized here, with no choose_qparams
    # call preceding it.
    fc_forward_checker = (
        FileCheck()
        .check_not(choose_qparams)
        .check(quant_func)
        .check("prim::CallFunction")
        .check_not(quant_func)
        .check("return")
    )
    fc_forward_checker.run(str(get_forward_graph(quantized.fc1._c)))
def test_optimize_for_mobile(self):
    """Exercise ``optimize_for_mobile`` on several scripted modules.

    The test covers, in order:

    * a grouped conv -> relu -> permute -> linear -> add -> relu module,
      checking the optimized graph with ``FileCheck`` and comparing its
      output with the unoptimized scripted module;
    * the same module with ``INSERT_FOLD_PREPACK_OPS`` blocklisted;
    * a conv + batch-norm module, with and without ``CONV_BN_FUSION``
      blocklisted;
    * the ``mobile_optimized`` attribute on an optimized module;
    * the ``preserved_methods`` argument for an exported method.
    """
    batch_size = 2
    input_channels_per_group = 6
    height = 16
    width = 16
    output_channels_per_group = 6
    groups = 4
    kernel_h = kernel_w = 3
    stride_h = stride_w = 1
    pad_h = pad_w = 1
    dilation = 1
    input_channels = input_channels_per_group * groups
    output_channels = output_channels_per_group * groups
    strides = (stride_h, stride_w)
    paddings = (pad_h, pad_w)
    dilations = (dilation, dilation)
    conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
    # A real 1-tuple; a bare parenthesised int is not a shape tuple.
    conv_bias_shape = (output_channels,)
    data_shape = (batch_size, input_channels, height, width)

    # Eager reference convolution; only its channel count is used below to
    # size the linear layer's weight.
    input_data = torch.rand(data_shape)
    conv_weight = torch.rand(conv_weight_shape)
    conv_bias = torch.rand(conv_bias_shape)
    result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
    weight_output_dim = 24
    linear_input_shape = result.shape[1]
    linear_weight_shape = (weight_output_dim, linear_input_shape)

    class MyTestModule(torch.nn.Module):
        def __init__(self):
            super(MyTestModule, self).__init__()
            self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)))
            self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)))
            self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
            self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand(weight_output_dim)))
            self.strides = strides
            self.paddings = paddings
            self.dilations = dilations
            self.groups = groups

        def forward(self, x):
            o = F.conv2d(x, self.conv_weight, self.conv_bias,
                         self.strides, self.paddings, self.dilations, self.groups)
            o = F.relu(o)
            x = o.permute([0, 2, 3, 1])
            o = F.linear(x, self.linear_weight, self.linear_bias)
            o = o + x
            return F.relu(o)

    class BNTestModule(torch.nn.Module):
        def __init__(self):
            super(BNTestModule, self).__init__()
            self.conv = torch.nn.Conv2d(1, 20, 5, 1)
            self.bn = torch.nn.BatchNorm2d(num_features=20)
            self.bn.eps = 0.0023

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # Blocklist that disables the prepack insertion/folding step; used for
    # both the conv/linear module and the conv/bn module below.
    optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}

    input_data = torch.normal(1, 20, size=data_shape)

    scripted_model = torch.jit.script(MyTestModule())
    scripted_model.eval()
    initial_result = scripted_model(input_data)

    # --- default optimization ------------------------------------------
    optimized_scripted_model = optimize_for_mobile(scripted_model)
    optimized_result = optimized_scripted_model(input_data)

    # Expect the aten conv/linear calls to be replaced by single prepacked
    # "run" ops with no leftover "prepack" ops, and the add + relu pair to
    # be replaced by a single aten::_add_relu.
    FileCheck().check_not("Tensor = aten::conv2d") \
               .check_not("Tensor = prim::CallFunction") \
               .check_not("prepacked::conv2d_clamp_prepack") \
               .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
               .check_not("prepacked::linear_clamp_prepack") \
               .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
               .check_not("aten::add(") \
               .check_not("aten::relu(") \
               .check_count("aten::_add_relu(", 1, exactly=True) \
               .run(optimized_scripted_model.graph)
    torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

    # --- prepack step blocklisted --------------------------------------
    optimized_scripted_model_no_prepack = optimize_for_mobile(
        scripted_model, optimization_blocklist_no_prepack)
    optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

    # The plain aten::conv2d must survive and no prepacked run ops appear.
    FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
               .check_not("prepacked::linear_clamp_run") \
               .check_not("prepacked::conv2d_clamp_run") \
               .run(optimized_scripted_model_no_prepack.graph)
    torch.testing.assert_allclose(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)

    # --- conv + batch norm ---------------------------------------------
    bn_test_module = BNTestModule()
    bn_scripted_module = torch.jit.script(bn_test_module)
    bn_scripted_module.eval()
    self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
    # Before optimization the top-level forward calls both submodules.
    FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
               .run(str(get_forward(bn_scripted_module._c).graph))

    # With prepacking blocklisted, the optimized module is expected to
    # export a single op name.
    bn_fold_scripted_module = optimize_for_mobile(
        bn_scripted_module, optimization_blocklist_no_prepack)
    self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
    bn_input = torch.rand(1, 1, 6, 6)
    torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

    # With conv/bn fusion blocklisted, exactly one batch_norm op remains.
    optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
    no_bn_fold_scripted_module = optimize_for_mobile(
        bn_scripted_module, optimization_blocklist_no_fold_bn)
    FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
               .run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
    bn_input = torch.rand(1, 1, 6, 6)
    torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

    # --- "mobile_optimized" tag ----------------------------------------
    class MyMobileOptimizedTagTest(torch.nn.Module):
        def __init__(self):
            super(MyMobileOptimizedTagTest, self).__init__()
            self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
            self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand(weight_output_dim)))

        def forward(self, x):
            o = F.linear(x, self.linear_weight, self.linear_bias)
            return F.relu(o)

    mobile_optimized_tag_module = MyMobileOptimizedTagTest()
    m = torch.jit.script(mobile_optimized_tag_module)
    m.eval()
    opt_m = optimize_for_mobile(m)
    tag = getattr(opt_m, "mobile_optimized", None)
    self.assertTrue(tag)

    # --- preserved methods ---------------------------------------------
    class MyPreserveMethodsTest(torch.nn.Module):
        def __init__(self):
            super(MyPreserveMethodsTest, self).__init__()
            self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
            self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand(weight_output_dim)))

        def forward(self, x):
            o = F.linear(x, self.linear_weight, self.linear_bias)
            return F.relu(o)

        @torch.jit.export
        def preserveThis(self):
            pass

    preserve_method_module = MyPreserveMethodsTest()
    m = torch.jit.script(preserve_method_module)
    m.eval()

    # Without preserved_methods the exported method is not available on
    # the optimized module ...
    opt_m = optimize_for_mobile(m)
    no_preserveThis = getattr(opt_m, "preserveThis", None)
    self.assertIsNone(no_preserveThis)

    # ... and it is kept when explicitly listed. Identity checks against
    # None avoid going through the attribute's own equality operator.
    opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
    preserveThis = getattr(opt_m, "preserveThis", None)
    self.assertIsNotNone(preserveThis)