Example #1
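Functional API test: run a float convolution as the reference, quantize its output, and compare it against the result of the quantized conv op. The snippet is a method of a test class and assumes the surrounding test-file context (numpy imported as np, torch, hypothesis's assume, and the _make_conv_test_input helper defined elsewhere in the file).
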
    def _test_conv_api_impl(
        self,
        qconv_fn,
        conv_fn,
        batch_size,
        in_channels_per_group,
        input_feature_map_size,
        out_channels_per_group,
        groups,
        kernel_size,
        stride,
        padding,
        dilation,
        X_scale,
        X_zero_point,
        W_scale,
        W_zero_point,
        Y_scale,
        Y_zero_point,
        use_bias,
        use_channelwise,
    ):
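        # Reject parameter draws where the dilated kernel does not fit in the
        # padded input (the convolution would have no valid output positions).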
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i] >= dilation[i] *
                   (kernel_size[i] - 1) + 1)
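        # Build matching float/quantized activations and weights, plus an
        # optional bias, for the requested configuration.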
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
            W_scale, W_zero_point, use_bias, use_channelwise)

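        # Reference path: float convolution, then quantize the result. The
        # quantized op under test should reproduce it up to rounding.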
        Y_exp = conv_fn(X, W, b, stride, padding, dilation, groups)
        Y_exp = torch.quantize_per_tensor(Y_exp,
                                          scale=Y_scale,
                                          zero_point=Y_zero_point,
                                          dtype=torch.quint8)
        Y_act = qconv_fn(X_q,
                         W_q,
                         b,
                         stride,
                         padding,
                         dilation,
                         groups,
                         padding_mode="zeros",
                         scale=Y_scale,
                         zero_point=Y_zero_point)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired-actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between reference
        # and test. Off-by-1 differences arise due to the order of round and
        # zero_point addition operation, i.e., if addition followed by round is
        # used by reference and round followed by addition is used by test, the
        # results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
        np.testing.assert_array_almost_equal(Y_exp.int_repr().numpy(),
                                             Y_act.int_repr().numpy(),
                                             decimal=0)
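
A minimal sketch of how this helper might be invoked from inside the test class; the pairing of torch.nn.quantized.functional.conv2d with torch.nn.functional.conv2d and all concrete values below are illustrative assumptions, not part of the original test (the real suite typically draws such parameters with hypothesis):

    # Hypothetical 2D invocation; per-tensor (not channelwise) weights.
    self._test_conv_api_impl(
        qconv_fn=torch.nn.quantized.functional.conv2d,  # op under test
        conv_fn=torch.nn.functional.conv2d,             # float reference
        batch_size=1,
        in_channels_per_group=2,
        input_feature_map_size=(10, 10),
        out_channels_per_group=4,
        groups=2,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        dilation=(1, 1),
        X_scale=1.3,
        X_zero_point=2,
        W_scale=[0.5],
        W_zero_point=[0],
        Y_scale=5.0,
        Y_zero_point=4,
        use_bias=True,
        use_channelwise=False,
    )
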
Example #2
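Module API test: the same comparison at the module level, followed by checks of the module's members and properties, state_dict serialization, a full-module torch.save/torch.load round trip, TorchScript scripting, and from_float conversion. Beyond the context noted for Example #1, this snippet also uses io and the harness's checkScriptable helper.
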
    def _test_conv_api_impl(
        self, module_name, qconv_module, conv_module, batch_size,
        in_channels_per_group, input_feature_map_size, out_channels_per_group,
        groups, kernel_size, stride, padding, dilation, X_scale, X_zero_point,
        W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused,
        use_channelwise,
    ):
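        # Reject parameter draws where the dilated kernel does not fit in the
        # padded input (the convolution would have no valid output positions).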
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i]
                   >= dilation[i] * (kernel_size[i] - 1) + 1)

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
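        # Build matching float/quantized activations and weights, plus an
        # optional bias, for the requested configuration.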
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
            W_scale, W_zero_point, use_bias, use_channelwise)

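        # Stage identical parameters in both modules: packed quantized
        # weight/bias in the quantized module, float copies in the reference
        # (index [0] when the reference is a fused Sequential).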
        qconv_module.set_weight_bias(W_q, b)
        qconv_module.scale = Y_scale
        qconv_module.zero_point = Y_zero_point

        if use_fused:
            conv_module[0].weight.data = W
            if use_bias:
                conv_module[0].bias.data = b
        else:
            conv_module.weight.data = W
            if use_bias:
                conv_module.bias.data = b

        # Test members
        self.assertTrue(module_name in str(qconv_module))
        self.assertTrue(hasattr(qconv_module, '_packed_params'))
        self.assertTrue(hasattr(qconv_module, 'scale'))
        self.assertTrue(hasattr(qconv_module, 'zero_point'))

        # Test properties
        self.assertEqual(W_q, qconv_module.weight())
        if use_bias:
            self.assertEqual(b, qconv_module.bias())
        self.assertEqual(Y_scale, qconv_module.scale)
        self.assertEqual(Y_zero_point, qconv_module.zero_point)

        # Test forward
        Y_exp = conv_module(X)
        Y_exp = torch.quantize_per_tensor(
            Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
        Y_act = qconv_module(X_q)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired-actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between reference
        # and test. Off-by-1 differences arise due to the order of round and
        # zero_point addition operation, i.e., if addition followed by round is
        # used by reference and round followed by addition is used by test, the
        # results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)

        # Test serialization of quantized Conv Module using state_dict
        model_dict = qconv_module.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], b)
        bytes_io = io.BytesIO()
        torch.save(model_dict, bytes_io)
        bytes_io.seek(0)
        loaded_dict = torch.load(bytes_io)
        for key in loaded_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qconv_module = type(qconv_module)(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, use_bias, padding_mode="zeros")
        loaded_qconv_module.load_state_dict(loaded_dict)

        self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
        self.assertTrue(module_name in str(loaded_qconv_module))
        self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
        self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))

        self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
        if use_bias:
            self.assertEqual(qconv_module.bias(), loaded_qconv_module.bias())
        self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
        self.assertEqual(qconv_module.zero_point,
                         loaded_qconv_module.zero_point)
        Y_loaded = loaded_qconv_module(X_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)

        # Round-trip the entire module through torch.save / torch.load.
        module_io = io.BytesIO()
        torch.save(qconv_module, module_io)
        module_io.seek(0)
        loaded_conv = torch.load(module_io)

        self.assertEqual(loaded_conv.bias(), qconv_module.bias())
        self.assertEqual(loaded_conv.scale, qconv_module.scale)
        self.assertEqual(loaded_conv.zero_point,
                         qconv_module.zero_point)

        # JIT testing
        self.checkScriptable(
            qconv_module, [[X_q]],
            check_save_load=True)

        # Test from_float
        conv_module.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(conv_module, inplace=True)
        conv_module(X.float())
        converted_qconv_module = torch.nn.Sequential(conv_module)
        torch.quantization.convert(converted_qconv_module, inplace=True)

        # Check that the converted module picked up the bias
        if use_bias:
            if use_fused:
                self.assertEqual(conv_module[0].bias,
                                 converted_qconv_module[0].bias())
            else:
                self.assertEqual(conv_module.bias,
                                 converted_qconv_module[0].bias())
        # Smoke test extra_repr
        self.assertTrue(module_name in str(converted_qconv_module))
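
A sketch of how the module-level helper might be driven for the 2D case; the modules and values below are illustrative assumptions rather than the original test's parameters:

    import torch.nn as nn
    import torch.nn.quantized as nnq

    # Hypothetical invocation with use_fused=False, so conv_module is a bare
    # nn.Conv2d rather than a Conv2d/ReLU Sequential. Here
    # in_channels = 2 * 2 = 4 and out_channels = 4 * 2 = 8.
    conv = nn.Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
                     dilation=(1, 1), groups=2, bias=True)
    qconv = nnq.Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1),
                       padding=(1, 1), dilation=(1, 1), groups=2, bias=True,
                       padding_mode="zeros")
    self._test_conv_api_impl(
        "QuantizedConv2d", qconv, conv,
        batch_size=1, in_channels_per_group=2,
        input_feature_map_size=(10, 10), out_channels_per_group=4, groups=2,
        kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1),
        X_scale=1.3, X_zero_point=2, W_scale=[0.5], W_zero_point=[0],
        Y_scale=5.0, Y_zero_point=4,
        use_bias=True, use_fused=False, use_channelwise=False)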