Example #1
                def process_weights(ihhh, layer, suffix, dtype):
                    weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
                    bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)

                    weight = getattr(mod, weight_name)
                    bias = getattr(mod, bias_name)

                    if dtype == torch.qint8:
                        # for each layer, for each direction we need to quantize and pack
                        # weights and pack parameters in this order:
                        #
                        #   w_ih, w_hh
                        weight_observer = weight_observer_method()
                        weight_observer(weight)
                        wt_scale, wt_zp = weight_observer.calculate_qparams()
                        qweight = torch.quantize_per_tensor(
                            weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
                        packed_weight = \
                            torch.ops.quantized.linear_prepack(qweight, bias)

                        params = [packed_weight]
                        pos_names = ['w']
                        ret_name = ['{}_{}_l{}{}'.format(
                            name, ihhh, layer, suffix) for name in pos_names]
                        return params, ret_name
                    else:
                        # for each layer, for each direction we need to quantize and pack
                        # weights and pack parameters in this order:
                        #
                        #   packed_ih, packed_hh, b_ih, b_hh
                        packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
                            weight.float())

                        params = [packed_weight, bias]
                        pos_names = ['packed', 'b']
                        ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
                        return params, ret_name
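
A minimal standalone sketch of the quantize-then-pack step described in the comments above, assuming a quantized engine (fbgemm or qnnpack) is available; the weight and bias here are arbitrary illustrative tensors, not part of the original snippet:

import torch

weight = torch.randn(8, 4)   # illustrative weight, shape (out_features, in_features)
bias = torch.zeros(8)

# Observe the weight and derive per-tensor qparams, as process_weights does.
observer = torch.quantization.default_weight_observer()
observer(weight)
wt_scale, wt_zp = observer.calculate_qparams()

# Quantize to qint8 and pack weight + bias for the quantized linear kernels.
qweight = torch.quantize_per_tensor(weight, float(wt_scale), int(wt_zp), torch.qint8)
packed_weight = torch.ops.quantized.linear_prepack(qweight, bias)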
Example #2
    def test_batchnorm_basic(self):
        """
        Basic test of the PyTorch 3D batchnorm Node on Glow.
        """

        class SimpleQuantizedBatchNorm(nn.Module):
            def __init__(self, C, running_mean, running_var, scale, zero_point):
                super(SimpleQuantizedBatchNorm, self).__init__()
                self.qconfig = my_qconfig
                self.batchnorm = nn.quantized.BatchNorm3d(C)
                self.batchnorm.scale = scale
                self.batchnorm.zero_point = zero_point
                self.batchnorm.running_mean = running_mean
                self.batchnorm.running_var = running_var
                self.relu = torch.nn.ReLU()
                self.dq = torch.nn.quantized.DeQuantize()

            def forward(self, x):
                return self.dq(self.relu(self.batchnorm(x)))

        C = 4
        in_scale = out_scale = 0.004
        in_zero_point = out_zero_point = 4
        running_mean = torch.zeros(C)
        running_var = torch.ones(C)

        inputs = torch.randn((5, C, 6, 32, 73), requires_grad=False)
        inputs = torch.quantize_per_tensor(
            inputs, scale=in_scale, zero_point=in_zero_point, dtype=torch.qint8
        )
        model = SimpleQuantizedBatchNorm(
            C, running_mean, running_var, out_scale, out_zero_point
        )
        model.eval()

        torch_glow.enable_convert_to_fp16()
        jitVsGlow(model, inputs, expected_fused_ops={"quantized::batch_norm3d"})
Example #3
    def test_qtensor_float_assignment(self):
        # Scalar Tensor
        # item
        scale = 1.0
        zero_point = 2
        r = torch.ones(1, dtype=torch.float)
        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
            self.assertEqual(qr.item(), 1)
            self.assertEqual(qr[0].item(), 1)
            # assignment
            self.assertTrue(qr[0].is_quantized)
            qr[0] = 11.3  # float assignment
            self.assertEqual(qr.item(), 11)
            x = torch.ones(1, dtype=torch.float) * 15.3
            # Copying from a float Tensor
            qr[:] = x
            self.assertEqual(qr.item(), 15)

            dtype_msg = str(dtype) + ", "
            self.assertEqual(
                ' '.join(str(qr).split()), "tensor([15.], size=(1,), dtype=" +
                dtype_msg + "quantization_scheme=torch.per_tensor_affine, " +
                "scale=1.0, zero_point=2)")
Example #4
    def init(self, N, dtype, contig, other_scalar, out_variant, op_func):
        # TODO: Consider more diverse shapes
        f_input = (torch.rand(N, N) - 0.5) * 256
        scale = 1.0
        zero_point = 0

        q_input_a = torch.quantize_per_tensor(f_input,
                                              scale=scale,
                                              zero_point=zero_point,
                                              dtype=dtype)
        if other_scalar:
            q_input_b = 42
        else:
            q_input_b = q_input_a.clone()

        if not contig:
            permute_dims = list(range(f_input.ndim))[::-1]
            q_input_a = q_input_a.permute(permute_dims)

        self.qop = op_func
        self.args = (q_input_a, q_input_b)
        self.kwargs = {}
        if out_variant:
            self.kwargs['out'] = torch.tensor([], dtype=torch.bool)
Example #5
    def _test_numerical_consistency(self, test_type):
        r"""Comparing numerical consistency between quantize/dequantize op and the fake quantize op across devices and dtypes
        """
        torch.random.manual_seed(NP_RANDOM_SEED)
        torch_types = [torch.qint8, torch.quint8]
        float_types = [torch.float, torch.float16, torch.float64]
        zero_types = [torch.int]
        devices = [torch.device('cpu'), torch.device('cuda')] if torch.cuda.is_available() else [torch.device('cpu')]
        axis = 1
        for i in range(20):
            for torch_type, float_type, device, zero_type in itertools.product(torch_types, float_types, devices, zero_types):
                X = torch.randn(3, 3, device=device).to(float_type)
                scales = (10 * torch.randn(3, device=device)).abs()
                scale = scales.mean().to(float).item()
                zeros = (10 * torch.randn(3, device=device)).abs().to(dtype=zero_type)
                zero = zeros.max().view(1).item()
                quant_min = torch.iinfo(torch_type).min
                quant_max = torch.iinfo(torch_type).max

                test_was_run = False
                if test_type == "per_tensor":
                    test_was_run = True
                    Y = torch.dequantize(torch.quantize_per_tensor(X.to('cpu').to(torch.float),
                                                                   scale, zero, torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_tensor_affine(X, scale, zero, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime, "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor")

                if test_type == "per_channel":
                    test_was_run = True
                    Y = torch.dequantize(torch.quantize_per_channel(X.to('cpu').to(torch.float), scales.to(
                        'cpu'), zeros.to('cpu'), axis, torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_channel_affine(X, scales, zeros, axis, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime, "Difference found between dequant+quant_per_channel and fake_quantize_per_channel")
                self.assertTrue(test_was_run)
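
The per-tensor branch above boils down to the following comparison; a minimal CPU-only sketch with arbitrary qparams:

import torch

x = torch.randn(3, 3)
scale, zero_point = 0.05, 10
quant_min = torch.iinfo(torch.quint8).min
quant_max = torch.iinfo(torch.quint8).max

# Real quantize/dequantize round trip.
y = torch.dequantize(torch.quantize_per_tensor(x, scale, zero_point, torch.quint8))
# Simulated quantization that stays in floating point.
y_fake = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)

assert torch.allclose(y, y_fake)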
Example #6
def quantize_node(node, activation_post_process):
    scale, zero_point, dtype = get_qparams(activation_post_process)
    return torch.quantize_per_tensor(node, scale, zero_point, dtype)
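
Here get_qparams appears to be a helper from the FX quantization internals; conceptually it just reads the scale, zero point, and dtype that the observer (activation_post_process) has computed. A rough sketch of the same idea using a public observer class:

import torch
from torch.quantization import MinMaxObserver

observer = MinMaxObserver(dtype=torch.quint8)  # stands in for activation_post_process
x = torch.randn(4, 4)
observer(x)                                    # record running min/max

scale, zero_point = observer.calculate_qparams()
qx = torch.quantize_per_tensor(x, float(scale), int(zero_point), observer.dtype)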
Example #7
    def test_serialize_graph(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.e = torch.rand(4)
                self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

            def forward(self, a, b, c):
                add_1 = a + b
                conv1 = self.conv(c)
                linear = self.linear(add_1 + conv1)
                add_2 = linear + self.e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        c = torch.rand(3, 3, 2, 2)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

        partitioner = Partitioner()
        devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
        partitioner_config = PartitionerConfig(devices,
                                               PartitionMode.sparse_nn)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        # Fix for now to add type/shape to output
        for node in traced.graph.nodes:
            if node.op == "output":
                node.meta['tensor_meta'] = extract_tensor_metadata(a)
        for mod in module_with_submodules.modules():
            if isinstance(mod, GraphModule):
                for node in mod.graph.nodes:
                    node.meta['tensor_meta'] = extract_tensor_metadata(a)
        for node in module_with_submodules.graph.nodes:
            node.meta['tensor_meta'] = extract_tensor_metadata(a)

        weights1 = {}
        weights2 = {}
        serialized_graph1 = graph_manipulation.serialize_module(
            traced, weights1)
        serialized_graph2 = graph_manipulation.serialize_module(
            module_with_submodules, weights2)
        assert len(weights1) == 4
        assert len(weights2) == 4
        assert len(serialized_graph1["nodes"]) == 10
        assert len(serialized_graph1["weights"]) == 4
        assert len(serialized_graph1["modules"]) == 0
        assert len(serialized_graph2["nodes"]) == 6
        assert len(serialized_graph2["weights"]) == 4
        assert len(serialized_graph2["modules"]) == 1
        assert serialized_graph1["weights"]["linear.weight"][
            "shape"] == "[4, 4]"
        assert (serialized_graph1["weights"]["linear.weight"]["dtype"] ==
                "torch.float32")
        assert (serialized_graph1["weights"]["linear.weight"]["is_quantized"]
                is False)
        assert serialized_graph1["nodes"][0]["shape"] == "[4]"
        assert serialized_graph1["nodes"][0]["dtype"] == "torch.float32"
        assert serialized_graph1["nodes"][0]["target"] == "a"
        assert serialized_graph1["nodes"][0]["op_code"] == "placeholder"
        assert serialized_graph1["nodes"][0]["name"] == "a"
        assert serialized_graph1["nodes"][6]["args"][0]["name"] == "add_1"
        assert serialized_graph1["nodes"][6]["args"][0]["is_node"] is True

        # Test quantization info serialization.
        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
        q_tensor_channel = torch.quantize_per_channel(
            x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0,
            torch.quint8)
        result = graph_manipulation.serialize_tensor_quantization(q_tensor)
        result2 = graph_manipulation.serialize_tensor_quantization(
            q_tensor_channel)
        assert result["qscheme"] == "torch.per_tensor_affine"
        assert result["q_scale"] == 1.0
        assert result2["qscheme"] == "torch.per_channel_affine"
        assert len(result2["q_per_channel_scales"]) == 2
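
The quantization-info assertions at the end of this test read metadata that is also available directly on any quantized tensor; for reference:

import torch

x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])

q = torch.quantize_per_tensor(x, 1.0, 0, torch.qint32)
print(q.qscheme())       # torch.per_tensor_affine
print(q.q_scale())       # 1.0
print(q.q_zero_point())  # 0

qc = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]),
                                torch.tensor([10, 0]), 0, torch.quint8)
print(qc.qscheme())              # torch.per_channel_affine
print(qc.q_per_channel_scales()) # tensor([0.1000, 0.0100], dtype=torch.float64)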
Example #8
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_fused):
        """test API functionality for nn.quantized.linear and nn._intrinsic.quantized.linear_relu"""
        W = torch.rand(out_features, in_features).float()
        W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        if use_fused:
            qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            qlinear = nnq.Linear(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_params

        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)
        # Check if the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale,
                                                    zero_point)
        else:
            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict

        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B)
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        if use_fused:
            loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params),
                         linear_unpack(loaded_qlinear._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(
            qlinear._weight_bias(),
            torch.ops.quantized.linear_unpack(qlinear._packed_params))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on save.
        # <start code>
        # b = io.BytesIO()
        # torch.save(qlinear, b)
        # b.seek(0)
        # loaded = torch.load(b)
        # self.assertEqual(qlinear.weight(), loaded.weight())
        # self.assertEqual(qlinear.scale, loaded.scale)
        # self.assertEqual(qlinear.zero_point, loaded.zero_point)
        # <end code>
        with self.assertRaisesRegex(
                RuntimeError, r'torch.save\(\) is not currently supported'):
            b = io.BytesIO()
            torch.save(qlinear, b)

        # Test JIT
        self.checkScriptable(qlinear,
                             list(zip([X_q], [Z_ref])),
                             check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear, inplace=True)
        float_linear(X.float())
        # Sequential allows swapping using "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        quantized_float_linear = torch.quantization.convert(
            quantized_float_linear, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        str(quantized_float_linear)
Example #9
    def _forward_impl(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # This version will not deal with the static key/value pairs.
        # Keeping it here for future changes.
        #
        # TODO: This method has some duplicate lines with the
        # `torch.nn.functional.multi_head_attention`. Will need to refactor.
        static_k = None
        static_v = None

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        scaling = float(head_dim)**-0.5

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

        q = self.q_scaling_product.mul_scalar(q, scaling)

        if attn_mask is not None:
            assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
                attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
                'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
            if attn_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
                )
                attn_mask = attn_mask.to(torch.bool)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError(
                        'The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                        bsz * self.num_heads,
                        query.size(0),
                        key.size(0)
                ]:
                    raise RuntimeError(
                        'The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError(
                    "attn_mask's dimension {} is not supported".format(
                        attn_mask.dim()))
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn(
                "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
            )
            key_padding_mask = key_padding_mask.to(torch.bool)
        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:
                k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = nnF.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        q = q.contiguous().view(tgt_len, bsz * self.num_heads,
                                head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads,
                                    head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads,
                                    head_dim).transpose(0, 1)

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(),
                                                    k.q_zero_point(), k.dtype)
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(),
                                                    v.q_zero_point(), v.dtype)
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = nnF.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = nnF.pad(key_padding_mask, (0, 1))

        # Leaving the quantized zone here
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bsz * self.num_heads, tgt_len, src_len
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(
                bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = nnF.softmax(attn_output_weights, dim=-1)
        attn_output_weights = nnF.dropout(attn_output_weights,
                                          p=self.dropout,
                                          training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(
            attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(
            tgt_len, bsz, self.embed_dim)

        # Reentering the quantized zone
        attn_output = self.quant_attn_output(attn_output)
        attn_output = self.out_proj(attn_output)  # type: ignore
        attn_output_weights = self.quant_attn_output_weights(
            attn_output_weights)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len)
            return attn_output, attn_output_weights.mean(dim=1)
        else:
            return attn_output, None
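
The add_zero_attn branch above has to quantize the padding zeros with the key/value tensor's own scale and zero point before concatenating; in isolation the pattern looks like this (a minimal sketch with arbitrary shapes and qparams):

import torch

k = torch.quantize_per_tensor(torch.randn(2, 5, 8), 0.1, 0, torch.quint8)

# Zeros are quantized with k's own qparams before concatenation, mirroring the code above.
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
k = torch.cat([k, k_zeros], dim=1)
print(k.shape)  # torch.Size([2, 6, 8]) -- one extra "zero attention" slot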
Example #10
    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
        W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear.set_weight_bias(W_q, B)
        qlinear(X)

        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_params._packed_params
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack, reduce_range=True)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            if isinstance(model_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                self.assertEqual(w_model, w_loaded)
                self.assertEqual(b_model, b_loaded)
            else:
                self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))

        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_dq2 = qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        loaded = torch.load(b)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear, [[X]], check_save_load=True)

        # Test from_float
        float_linear = torch.nn.Linear(in_features, out_features).float()
        if use_default_observer:
            float_linear.qconfig = torch.quantization.default_dynamic_qconfig
        prepare_dynamic(float_linear)
        float_linear(X.float())
        quantized_float_linear = nnqd.Linear.from_float(float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X)

        # Smoke test extra_repr
        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
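
The from_float sequence at the end of this test is what the one-shot torch.quantization.quantize_dynamic helper does for eligible module types; a minimal sketch (module sizes are arbitrary):

import torch
import torch.nn as nn

float_model = nn.Sequential(nn.Linear(16, 8))
# Replaces nn.Linear with its dynamically quantized counterpart (qint8 weights,
# float activations quantized on the fly at each call).
dq_model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

out = dq_model(torch.randn(4, 16))
print(dq_model)  # shows DynamicQuantizedLinear in place of Linear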
Example #11
    def op_convert_before_hook(
        self,
        op: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        root_module: torch.nn.Module,
    ) -> Tuple[Callable, Tuple[Any, ...], Dict[str, Any]]:
        """
        This function is called before an op call in a converted model.

        For each arg in `args`, quantizes it if necessary.

        Returns potentially modified `op`, potentially modified `args`,
        potentially modified `kwargs`.
        """
        # TODO generalize this for more things
        # currently:
        # * can quantize args (via arg_quant_infos)
        # * can add scale and zp (via additional kwargs)

        # needed for F.conv2d
        # F.conv2d(input, weight, bias, stride, padding, dilation, groups)
        #   to
        # q.conv2d(input, packed_params, scale, zero_point)
        orig_op = op
        maybe_new_op, arg_quant_infos, arg_dequant_infos, packed_param_name, \
            additional_kwargs, any_arg_quant_or_dequant_needed, \
            any_arg_kwarg_modification_needed = self.get_op_convert_info(op)
        if maybe_new_op is not None:
            op = maybe_new_op
        if not any_arg_kwarg_modification_needed:
            return op, args, kwargs
        # print(op, arg_quant_infos, packed_param_name, additional_kwargs)

        # potentially quantize args, based on arg_quant_infos
        new_args = []
        if any_arg_quant_or_dequant_needed:
            tensor_arg_idx = 0
            # TODO: refactor this to use iterate_and_apply
            if orig_op is torch.cat:  # torch.cat variants
                # input tensors
                new_first_arg = []
                for arg in args[0]:
                    # TODO: handle non-tensor inputs
                    quant_info = arg_quant_infos[tensor_arg_idx]
                    dequant_info = arg_dequant_infos[tensor_arg_idx]
                    if quant_info is not None:
                        scale, zp = quant_info
                        arg = torch.quantize_per_tensor(
                            arg, scale, zp, torch.quint8)
                    elif dequant_info is True:
                        arg = arg.dequantize()
                    new_first_arg.append(arg)
                    tensor_arg_idx += 1
                new_args = [new_first_arg, *args[1:]]
            else:
                for arg in args:
                    # TODO: handle non-tensor inputs
                    # TODO: this is not handling non-tensor tuple args (for example,
                    # dilation in conv2d) correctly, it just happens to work but
                    # needs a fix.
                    quant_info = arg_quant_infos[tensor_arg_idx]
                    dequant_info = arg_dequant_infos[tensor_arg_idx]
                    if quant_info is not None:
                        scale, zp = quant_info
                        arg = torch.quantize_per_tensor(
                            arg, scale, zp, torch.quint8)
                    elif dequant_info is True:
                        arg = arg.dequantize()
                    new_args.append(arg)
                    tensor_arg_idx += 1
        else:
            new_args = [*args]

        # if there is a packed param, replace the relevant args
        if packed_param_name is not None:
            new_args_with_packed = []
            packable_arg_idxs = get_packable_arg_idxs(orig_op)
            added_packed = False
            for idx, arg in enumerate(new_args):
                if packable_arg_idxs is not None and idx in packable_arg_idxs:
                    if not added_packed:
                        packed_param = getattr(root_module, packed_param_name)
                        new_args_with_packed.append(packed_param)
                        added_packed = True
                else:
                    new_args_with_packed.append(arg)
            new_args = new_args_with_packed

        # potentially extend kwargs with scale and zero_point
        # TODO move op-specific logic out of here
        if len(additional_kwargs):
            if orig_op not in (F.conv2d, F.linear):
                kwargs.update(**additional_kwargs)
            else:
                seen_op_info = self._get_cur_seen_op_info()
                if seen_op_info.output_tensor_infos[
                        0].inf_dtype == torch.quint8:
                    new_args.append(additional_kwargs['scale'])
                    new_args.append(additional_kwargs['zero_point'])

        # TODO move op-specific logic out of here
        if op is torch.ops.quantized.linear:
            kwargs.pop('bias', None)

        return op, tuple(new_args), kwargs
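
For the F.conv2d -> quantized conv2d rewrite sketched in the comments above, the functional quantized op shows the extra scale/zero_point arguments explicitly. A hedged sketch, assuming the fbgemm or qnnpack engine is available; shapes and qparams are arbitrary:

import torch
import torch.nn.quantized.functional as qF

# Quantized activation (quint8) and weight (qint8), as the hook prepares them.
x_q = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), 0.1, 64, torch.quint8)
w_q = torch.quantize_per_tensor(torch.randn(4, 3, 3, 3), 0.05, 0, torch.qint8)
bias = torch.zeros(4)

# Unlike F.conv2d, the quantized op also takes the output scale and zero_point.
y_q = qF.conv2d(x_q, w_q, bias, stride=1, padding=0, scale=0.2, zero_point=32)
print(y_q.q_scale(), y_q.q_zero_point())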
Example #12
strides = (1, 1)
pads = (0, 0)
o_pads = (0, 0)
dilations = (1, 1)
groups = 2

y_s = 4.2
y_zp = 0

#################

X = torch.from_numpy(X_np).to(torch.float)
W = torch.from_numpy(W_np).to(torch.float)

X_q = torch.quantize_per_tensor(X, X_s, X_zp, X_dtype)
W_q = torch.quantize_per_tensor(W, W_s, W_zp, W_dtype)

X_dq = X_q.dequantize()
W_dq = W_q.dequantize()

print(f'Input shape: {X.shape}, Weight shape: {W.shape}')

#################

iC = W_dq.shape[0]
oC = W_dq.shape[1] * groups
kernel_size = W_dq.shape[2:]

conv_ref = nn.ConvTranspose2d(iC,
                              oC,
Example #13
    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:

        layer = torch.quantize_per_tensor(layer.float(), self.scale,
                                          self.zero_point, self.dtype)

        return layer
Example #14
    def _test_op(self,
                 qmodule,
                 subname=None,
                 input_size=None,
                 input_quantized=True,
                 generate=False,
                 prec=None,
                 new_zipfile_serialization=False):
        r""" Test quantized modules serialized previously can be loaded
        with current code, make sure we don't break backward compatibility for the
        serialization of quantized modules
        """
        def remove_prefix(text, prefix):
            if text.startswith(prefix):
                return text[len(prefix):]
            return text

        # NB: we take __file__ from the module that defined the test
        # class, so we place the expect directory where the test script
        # lives, NOT where test/common_utils.py lives.
        module_id = self.__class__.__module__
        munged_id = remove_prefix(self.id(), module_id + ".")
        test_file = os.path.realpath(sys.modules[module_id].__file__)
        base_name = os.path.join(os.path.dirname(test_file), "serialized",
                                 munged_id)

        subname_output = ""
        if subname:
            base_name += "_" + subname
            subname_output = " ({})".format(subname)

        input_file = base_name + ".input.pt"
        state_dict_file = base_name + ".state_dict.pt"
        scripted_module_file = base_name + ".scripted.pt"
        traced_module_file = base_name + ".traced.pt"
        expected_file = base_name + ".expected.pt"

        # only generate once.
        if generate and qengine_is_fbgemm():
            input_tensor = torch.rand(*input_size).float()
            if input_quantized:
                input_tensor = torch.quantize_per_tensor(
                    input_tensor, 0.5, 2, torch.quint8)
            torch.save(input_tensor, input_file)
            # Temporary fix to use _use_new_zipfile_serialization until #38379 lands.
            torch.save(
                qmodule.state_dict(),
                state_dict_file,
                _use_new_zipfile_serialization=new_zipfile_serialization)
            torch.jit.save(torch.jit.script(qmodule), scripted_module_file)
            torch.jit.save(torch.jit.trace(qmodule, input_tensor),
                           traced_module_file)
            torch.save(qmodule(input_tensor), expected_file)

        input_tensor = torch.load(input_file)
        qmodule.load_state_dict(torch.load(state_dict_file))
        qmodule_scripted = torch.jit.load(scripted_module_file)
        qmodule_traced = torch.jit.load(traced_module_file)
        expected = torch.load(expected_file)
        self.assertEqual(qmodule(input_tensor), expected, atol=prec)
        self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
        self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)
Example #15
reveal_type(torch.eye(3))  # E: {Tensor}

# torch.empty/empty_like/empty_strided
reveal_type(torch.empty(2, 3))  # E: {Tensor}
reveal_type(torch.empty_like(torch.empty(2, 3),
                             dtype=torch.int64))  # E: {Tensor}
reveal_type(torch.empty_strided((2, 3), (1, 2)))  # E: {Tensor}

# torch.full/full_like
reveal_type(torch.full((2, 3), 3.141592))  # E: {Tensor}
reveal_type(torch.full_like(torch.full((2, 3), 3.141592),
                            2.71828))  # E: {Tensor}

# torch.quantize_per_tensor
reveal_type(
    torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10,
                              torch.quint8))  # E: {Tensor}

# torch.quantize_per_channel
x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
quant = torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]),
                                   torch.tensor([10, 0]), 0, torch.quint8)
reveal_type(x)  # E: {Tensor}

# torch.dequantize
reveal_type(torch.dequantize(x))  # E: {Tensor}

# torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
reveal_type(torch.complex(real, imag))  # E: {Tensor}
Example #16
    def test_quantized_add_relu_fusion(self):
        class MAdd(torch.nn.Module):
            def __init__(self):
                super(MAdd, self).__init__()

            def forward(self, x, y):
                a = torch.ops.quantized.add(x, y, 1., 0)
                relu_out = torch.relu(a)
                return relu_out

        A = torch.arange(-128, 130, dtype=torch.float)
        B = torch.arange(-128, 130, dtype=torch.float)
        scale = 2.0
        zero_point = 127
        qA = torch.quantize_per_tensor(A, scale=scale, zero_point=zero_point,
                                       dtype=torch.quint8)
        qB = torch.quantize_per_tensor(B, scale=scale, zero_point=zero_point,
                                       dtype=torch.quint8)

        # Check quantized add + relu fusion
        m = MAdd()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB)

        # Must inline the graph.
        # In this test case since we are directly calling ops
        # it does not matter, however if we are calling nn
        # modules we have to inline graph.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu") \
                   .check("quantized::add_relu") \
                   .run(scripted_m.graph)
        output = scripted_m(qA, qB)
        self.assertEqual(ref_output, output)

        class MAddOut(torch.nn.Module):
            def __init__(self):
                super(MAddOut, self).__init__()

            def forward(self, x, y, z):
                a = torch.ops.quantized.add_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(qA.shape,
                                           scale=scale,
                                           zero_point=zero_point,
                                           dtype=torch.quint8)
        # Check quantized add + relu fusion
        m = MAddOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB, qC)
        # Must inline the graph.
        # In this test case since we are directly calling ops
        # it does not matter, however if we are calling nn
        # modules we have to inline graph.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu") \
                   .check_not("quantized::add_out") \
                   .check("quantized::add_relu_out") \
                   .run(scripted_m.graph)
        output = scripted_m(qA, qB, qC)
        self.assertEqual(ref_output, output)

        class MAddScalar(torch.nn.Module):
            def __init__(self):
                super(MAddScalar, self).__init__()

            def forward(self, x, y : float):
                a = torch.ops.quantized.add_scalar(x, y)
                relu_out = torch.relu(a)
                return relu_out

        # Check quantized add + relu fusion
        m = MAddScalar()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu") \
                   .check_not("quantized::add_scalar(") \
                   .check("quantized::add_scalar_relu") \
                   .run(scripted_m.graph)
        output = scripted_m(qA, 3.)
        self.assertEqual(ref_output, output)

        class MAddScalarOut(torch.nn.Module):
            def __init__(self):
                super(MAddScalarOut, self).__init__()

            def forward(self, x, y : float, z):
                a = torch.ops.quantized.add_scalar_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(qA.shape,
                                           scale=scale,
                                           zero_point=zero_point,
                                           dtype=torch.quint8)
        m = MAddScalarOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3., qC)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu") \
                   .check_not("quantized::add_scalar_out") \
                   .check("quantized::add_scalar_relu_out") \
                   .run(scripted_m.graph)
        output = scripted_m(qA, 3., qC)
        self.assertEqual(ref_output, output)
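
The fusion checked above is behaviorally the same as calling the fused op directly; a small sketch comparing the two paths with the same output qparams convention as the test:

import torch

a = torch.quantize_per_tensor(torch.randn(10), 2.0, 127, torch.quint8)
b = torch.quantize_per_tensor(torch.randn(10), 2.0, 127, torch.quint8)

# Unfused: quantized add followed by relu on the quantized result.
unfused = torch.relu(torch.ops.quantized.add(a, b, 1.0, 0))
# Fused kernel that the _jit_pass_fuse_quantized_add_relu pass rewrites to.
fused = torch.ops.quantized.add_relu(a, b, 1.0, 0)

assert torch.equal(unfused.int_repr(), fused.int_repr())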
Example #17
    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel):
        """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu"""
        if torch.backends.quantized.engine == 'qnnpack':
            per_channel = False
        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        if use_fused:
            qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            qlinear = nnq.Linear(in_features, out_features)

        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)
        W_pack = qlinear._packed_params._packed_params

        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)
        # Check if the module implementation matches calling the
        # ops directly
        if use_fused:
            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)

            self.assertTrue('QuantizedLinearReLU' in str(qlinear))
        else:
            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)

            self.assertTrue('QuantizedLinear' in str(qlinear))
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            if isinstance(model_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                self.assertEqual(w_model, w_loaded)
                self.assertEqual(b_model, b_loaded)
            else:
                self.assertEqual(model_dict[key], loaded_dict[key])
        if use_fused:
            loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
        else:
            loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        loaded = torch.load(b)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Test from_float.
        float_linear = torch.nn.Linear(in_features, out_features).float()
        float_linear.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_linear, inplace=True)
        float_linear(X.float())
        # Sequential allows swapping using "convert".
        quantized_float_linear = torch.nn.Sequential(float_linear)
        quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X_q)

        # Smoke test extra_repr
        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
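
The two weight-quantization branches above differ only in the granularity of the qparams: per-tensor keeps one scale for the whole weight, per-channel keeps one scale per output channel (axis 0). A small sketch for reference:

import torch

W = torch.randn(3, 4)  # (out_features, in_features)

# Per-tensor: a single scale/zero_point for the whole tensor.
W_pt = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)
print(W_pt.q_scale(), W_pt.q_zero_point())

# Per-channel: one scale/zero_point per output channel (axis=0).
scales = torch.tensor([0.1, 0.05, 0.2], dtype=torch.double)
zero_points = torch.zeros(3, dtype=torch.long)
W_pc = torch.quantize_per_channel(W, scales, zero_points, axis=0, dtype=torch.qint8)
print(W_pc.q_per_channel_scales())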
Example #18
    def __init__(self, mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0., bidirectional=False, dtype=torch.qint8):
        super(RNNBase, self).__init__()

        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.dtype = dtype
        self.version = 2
        self.training = False
        num_directions = 2 if bidirectional else 1

        # "type: ignore" is required since ints and Numbers are not fully comparable
        # https://github.com/python/mypy/issues/8566
        if not isinstance(dropout, numbers.Number) \
                or not 0 <= dropout <= 1 or isinstance(dropout, bool):  # type: ignore[operator]
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0 and num_layers == 1:  # type: ignore[operator]
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        elif mode == 'GRU':
            gate_size = 3 * hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        _all_weight_values = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = torch.randn(gate_size, layer_input_size).to(torch.float)
                w_hh = torch.randn(gate_size, hidden_size).to(torch.float)
                b_ih = torch.randn(gate_size).to(torch.float)
                b_hh = torch.randn(gate_size).to(torch.float)
                if dtype == torch.qint8:
                    w_ih = torch.quantize_per_tensor(w_ih, scale=0.1, zero_point=0, dtype=torch.qint8)
                    w_hh = torch.quantize_per_tensor(w_hh, scale=0.1, zero_point=0, dtype=torch.qint8)
                    packed_ih = \
                        torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = \
                        torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, b_ih, b_hh)
                    else:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, b_ih, b_hh, True)

                else:

                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh)

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
Example #19
    def _test_conv_api_impl(
        self,
        module_name,
        qconv_module,
        conv_module,
        batch_size,
        in_channels_per_group,
        input_feature_map_size,
        out_channels_per_group,
        groups,
        kernel_size,
        stride,
        padding,
        dilation,
        X_scale,
        X_zero_point,
        W_scale,
        W_zero_point,
        Y_scale,
        Y_zero_point,
        use_bias,
        use_fused,
        use_channelwise,
    ):
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i] >= dilation[i] *
                   (kernel_size[i] - 1) + 1)

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
            W_scale, W_zero_point, use_bias, use_channelwise)

        qconv_module.set_weight_bias(W_q, b)
        qconv_module.scale = Y_scale
        qconv_module.zero_point = Y_zero_point

        if use_fused:
            conv_module[0].weight.data = W
            if use_bias:
                conv_module[0].bias.data = b
        else:
            conv_module.weight.data = W
            if use_bias:
                conv_module.bias.data = b

        # Test members
        self.assertTrue(module_name in str(qconv_module))
        self.assertTrue(hasattr(qconv_module, '_packed_params'))
        self.assertTrue(hasattr(qconv_module, 'scale'))
        self.assertTrue(hasattr(qconv_module, 'zero_point'))

        # Test properties
        self.assertEqual(W_q, qconv_module.weight())
        if use_bias:
            self.assertEqual(b, qconv_module.bias())
        self.assertEqual(Y_scale, qconv_module.scale)
        self.assertEqual(Y_zero_point, qconv_module.zero_point)

        # Test forward
        Y_exp = conv_module(X)
        Y_exp = torch.quantize_per_tensor(Y_exp,
                                          scale=Y_scale,
                                          zero_point=Y_zero_point,
                                          dtype=torch.quint8)
        Y_act = qconv_module(X_q)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired-actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between reference
        # and test. Off-by-1 differences arise due to the order of round and
        # zero_point addition operation, i.e., if addition followed by round is
        # used by reference and round followed by addition is used by test, the
        # results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
        np.testing.assert_array_almost_equal(Y_exp.int_repr().numpy(),
                                             Y_act.int_repr().numpy(),
                                             decimal=0)

        # Test serialization of quantized Conv Module using state_dict
        model_dict = qconv_module.state_dict()
        self.assertEqual(W_q, model_dict['weight'])
        if use_bias:
            self.assertEqual(b, model_dict['bias'])
        bytes_io = io.BytesIO()
        torch.save(model_dict, bytes_io)
        bytes_io.seek(0)
        loaded_dict = torch.load(bytes_io)
        for key in loaded_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])

        loaded_qconv_module = type(qconv_module)(in_channels,
                                                 out_channels,
                                                 kernel_size,
                                                 stride,
                                                 padding,
                                                 dilation,
                                                 groups,
                                                 use_bias,
                                                 padding_mode="zeros")
        loaded_qconv_module.load_state_dict(loaded_dict)

        self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
        self.assertTrue(module_name in str(loaded_qconv_module))
        self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
        self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))

        self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
        if use_bias:
            self.assertEqual(qconv_module.bias(), loaded_qconv_module.bias())
        self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
        self.assertEqual(qconv_module.zero_point,
                         loaded_qconv_module.zero_point)
        Y_loaded = loaded_qconv_module(X_q)
        np.testing.assert_array_almost_equal(Y_exp.int_repr().numpy(),
                                             Y_loaded.int_repr().numpy(),
                                             decimal=0)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on
        # save.
        # <start code>
        # b = io.BytesIO()
        # torch.save(conv_under_test, b)
        # b.seek(0)
        # loaded_conv = torch.load(b)
        #
        # self.assertEqual(loaded_qconv_module.bias(), qconv_module.bias())
        # self.assertEqual(loaded_qconv_module.scale, qconv_module.scale)
        # self.assertEqual(loaded_qconv_module.zero_point,
        #                  qconv_module.zero_point)
        # <end code>
        with self.assertRaisesRegex(
                RuntimeError, r'torch.save\(\) is not currently supported'):
            bytes_io = io.BytesIO()
            torch.save(qconv_module, bytes_io)

        # JIT testing
        self.checkScriptable(qconv_module,
                             list(zip([X_q], [Y_exp])),
                             check_save_load=True)

        # Test from_float
        conv_module.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(conv_module, inplace=True)
        conv_module(X.float())
        converted_qconv_module = torch.nn.Sequential(conv_module)
        torch.quantization.convert(converted_qconv_module, inplace=True)

        # Smoke test to make sure the module actually runs
        if use_bias:
            if use_fused:
                self.assertEqual(conv_module[0].bias,
                                 converted_qconv_module[0].bias())
            else:
                self.assertEqual(conv_module.bias,
                                 converted_qconv_module[0].bias())
        # Smoke test extra_repr
        self.assertTrue(module_name in str(converted_qconv_module))
    def _test_conv_api_impl(
            self, module_name, qconv_module, conv_module, batch_size,
            in_channels_per_group, input_feature_map_size, out_channels_per_group,
            groups, kernel_size, stride, padding, padding_mode, dilation,
            X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
            use_bias, use_fused, use_channelwise):
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i]
                   >= dilation[i] * (kernel_size[i] - 1) + 1)

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
            W_scale, W_zero_point, use_bias, use_channelwise)
        # Make sure the weight shape is correct
        self.assertTrue(qconv_module.weight().shape == W_q.shape)

        qconv_module.set_weight_bias(W_q, b)
        qconv_module.scale = Y_scale
        qconv_module.zero_point = Y_zero_point

        if use_fused:
            conv_module[0].weight.data = W
            if use_bias:
                conv_module[0].bias.data = b
        else:
            conv_module.weight.data = W
            if use_bias:
                conv_module.bias.data = b

        # Test members
        self.assertTrue(module_name == qconv_module._get_name(), module_name + " " + qconv_module._get_name())
        self.assertTrue(hasattr(qconv_module, '_packed_params'))
        self.assertTrue(hasattr(qconv_module, 'scale'))
        self.assertTrue(hasattr(qconv_module, 'zero_point'))

        # Test properties
        self.assertEqual(W_q, qconv_module.weight())
        if use_bias:
            self.assertEqual(b, qconv_module.bias())
        self.assertEqual(Y_scale, qconv_module.scale)
        self.assertEqual(Y_zero_point, qconv_module.zero_point)

        # Test forward
        Y_exp = conv_module(X)
        Y_exp = torch.quantize_per_tensor(
            Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
        Y_act = qconv_module(X_q)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired-actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between the
        # reference and the test. Off-by-1 differences arise from the order of
        # the rounding and zero_point-addition operations: if the reference
        # adds the zero_point and then rounds while the test rounds and then
        # adds the zero_point, the results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)

        # Test serialization of quantized Conv Module using state_dict
        model_dict = qconv_module.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], b)
        bytes_io = io.BytesIO()
        torch.save(model_dict, bytes_io)
        bytes_io.seek(0)
        loaded_dict = torch.load(bytes_io)
        for key in loaded_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qconv_module = type(qconv_module)(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, use_bias, padding_mode=padding_mode)
        loaded_qconv_module.load_state_dict(loaded_dict)

        self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
        self.assertTrue(module_name == loaded_qconv_module._get_name())
        self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
        self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))

        self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
        if use_bias:
            self.assertEqual(qconv_module.bias(), loaded_qconv_module.bias())
        self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
        self.assertEqual(qconv_module.zero_point,
                         loaded_qconv_module.zero_point)
        Y_loaded = loaded_qconv_module(X_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)

        # Test serialization
        b = io.BytesIO()
        torch.save(qconv_module, b)
        b.seek(0)
        loaded_conv = torch.load(b)

        self.assertEqual(loaded_conv.bias(), qconv_module.bias())
        self.assertEqual(loaded_conv.scale, qconv_module.scale)
        self.assertEqual(loaded_conv.zero_point,
                         qconv_module.zero_point)

        # Test copy and deepcopy
        copied_conv = copy.copy(qconv_module)
        self.assertEqual(copied_conv.bias(), qconv_module.bias())
        self.assertEqual(copied_conv.scale, qconv_module.scale)
        self.assertEqual(copied_conv.zero_point,
                         qconv_module.zero_point)
        Y_copied = copied_conv(X_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_conv = copy.deepcopy(qconv_module)
        self.assertEqual(deepcopied_conv.bias(), qconv_module.bias())
        self.assertEqual(deepcopied_conv.scale, qconv_module.scale)
        self.assertEqual(deepcopied_conv.zero_point,
                         qconv_module.zero_point)
        Y_deepcopied = deepcopied_conv(X_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # JIT testing
        self.checkScriptable(
            qconv_module, [[X_q]],
            check_save_load=True)

        # Test from_float
        fused_conv_module = torch.nn.intrinsic._FusedModule(conv_module)
        fused_conv_module.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(fused_conv_module, inplace=True)
        fused_conv_module(X.float())
        converted_qconv_module = fused_conv_module
        reference_mapping = get_default_static_quant_module_mappings()
        reference_mapping[type(conv_module)] = type(qconv_module)
        torch.quantization.convert(converted_qconv_module, mapping=reference_mapping, inplace=True)

        # Smoke test to make sure the module actually runs
        if use_bias:
            if use_fused:
                self.assertEqual(conv_module[0].bias,
                                 converted_qconv_module[0].bias())
            else:
                self.assertEqual(conv_module.bias,
                                 converted_qconv_module[0].bias())
        # Smoke test extra_repr
        self.assertTrue(module_name == converted_qconv_module[0]._get_name())
def regular_serialization():
    test_cases = {}
    for dtype, device in itertools.product(all_dtypes, all_devices):
        base_name = f'regular_serialization_{dtype_name(dtype)}_{device}'

        test_cases[f'{base_name}_0'] = [
            make_tensor((3, 5), device=device, dtype=dtype, low=-9, high=9)
        ]

        a = make_tensor((15, 5, 5), device=device, dtype=dtype, low=-9, high=9)
        test_cases[f'{base_name}_1'] = [
            get_storage(a),
            a.view((5, 3, 25)),
            a,
            a[1:],
        ]

        if dtype.is_floating_point or dtype.is_complex:
            m = torch.nn.Linear(50, 10, dtype=dtype, device=device)
            test_cases[f'{base_name}_module_0'] = [m]

        # Quantization
        if dtype == torch.float and device == 'cpu':
            for qdtype in [
                    torch.quint8, torch.qint8, torch.qint32, torch.quint4x2
            ]:
                a = make_tensor((10, 3, 8, 2, 4),
                                device=device,
                                dtype=dtype,
                                low=-9,
                                high=9)
                q = torch.quantize_per_tensor(a, 1.0, 2, qdtype)
                test_cases[f'{base_name}_quant_0_{dtype_name(qdtype)}'] = [q]
                test_cases[f'{base_name}_quant_1_{dtype_name(qdtype)}'] = [
                    a, q
                ]

                # TODO: For some reason, qint32 throws an illegal instruction
                # error, for both master and local branch. Either I'm doing
                # something wrong or it's an actual problem. Either way,
                # I should file an issue
                if qdtype == torch.qint32:
                    continue

                a = make_tensor((10, 3, 8, 2, 4),
                                device=device,
                                dtype=dtype,
                                low=-9,
                                high=9)
                scales = make_tensor((8, ),
                                     device=device,
                                     dtype=dtype,
                                     low=-9,
                                     high=9)
                zero_points = make_tensor((8, ),
                                          device=device,
                                          dtype=dtype,
                                          low=-9,
                                          high=9)
                q = torch.quantize_per_channel(a, scales, zero_points, 2,
                                               qdtype)
                test_cases[
                    f'{base_name}_quant_channel_0_{dtype_name(qdtype)}'] = [q]
                test_cases[
                    f'{base_name}_quant_channel_1_{dtype_name(qdtype)}'] = [
                        a, q
                    ]

        # TODO: test sparse COO
        # TODO: test packaging

    return test_cases
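
# A minimal driver for the generator above might look like this sketch. The
# file naming is hypothetical; it only assumes `torch` is importable and that
# regular_serialization() is defined as above.
def save_all_regular_serialization_cases(out_dir='.'):
    import os
    for case_name, objects in regular_serialization().items():
        # Each case is a list of tensors/modules expected to round-trip
        # through torch.save / torch.load.
        torch.save(objects, os.path.join(out_dir, f'{case_name}.pt'))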
    def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel):
        if torch.backends.quantized.engine == 'qnnpack':
            per_channel = False

        # use_fused -> quantized class
        class_map = {
            True: nniq.LinearReLU,
            False: nnq.Linear,
        }

        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        qlinear = class_map[use_fused](in_features, out_features)

        qlinear_copy = copy.deepcopy(qlinear)
        self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)

        # testing packed param implementation
        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)

        # Check if the module implementation matches calling the
        # ops directly
        W_pack = qlinear._packed_params._packed_params
        if use_fused:
            Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point)
        else:
            Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point)

        self.assertEqual(Z_ref, Z_q)
        self.assertTrue(
            ("QuantizedLinearReLU" if use_fused else "QuantizedLinear") in str(qlinear))

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            if isinstance(model_dict[key], torch._C.ScriptObject):
                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                self.assertEqual(w_model, w_loaded)
                self.assertEqual(b_model, b_loaded)
            else:
                self.assertEqual(model_dict[key], loaded_dict[key])

        loaded_qlinear = class_map[use_fused](
            in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)
        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        self.assertEqual(qlinear.scale, loaded_qlinear.scale)
        self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
        # scripting will add __overloads__ to __dict__, which is why we script a copy
        # to be able to do the check in the next line
        self.checkScriptable(copy.deepcopy(loaded_qlinear), [[X_q]], check_save_load=True)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_q2 = loaded_qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)

        # Test serialization
        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        loaded = torch.load(b)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test copy and deepcopy
        copied_linear = copy.copy(qlinear)
        self.assertEqual(copied_linear.bias(), qlinear.bias())
        self.assertEqual(copied_linear.scale, qlinear.scale)
        self.assertEqual(copied_linear.zero_point,
                         qlinear.zero_point)
        Y_copied = copied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_linear = copy.deepcopy(qlinear)
        self.assertEqual(deepcopied_linear.bias(), qlinear.bias())
        self.assertEqual(deepcopied_linear.scale, qlinear.scale)
        self.assertEqual(deepcopied_linear.zero_point,
                         qlinear.zero_point)
        Y_deepcopied = deepcopied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # Test JIT
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Make sure `from_float` works for all linear variants
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]

        for mut in modules_under_test:
            # Test from_float.
            float_linear = mut(in_features, out_features).float()
            float_linear.qconfig = torch.quantization.default_qconfig
            torch.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Sequential allows swapping using "convert".
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.quantization.convert(quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
    def test_linear_api(self, batch_size, in_features, out_features, use_bias,
                        use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8)
        W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X)
        qlinear.set_weight_bias(W_q, B)

        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_params
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], B)
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params),
                         linear_unpack(loaded_qlinear._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))

        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(
            qlinear._weight_bias(),
            torch.ops.quantized.linear_unpack(qlinear._packed_params))
        Z_dq2 = qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on save.
        # <start code>
        # b = io.BytesIO()
        # torch.save(qlinear, b)
        # b.seek(0)
        # loaded = torch.load(b)
        # self.assertEqual(qlinear.weight(), loaded.weight())
        # self.assertEqual(qlinear.zero_point, loaded.zero_point)
        # <end code>
        with self.assertRaisesRegex(
                RuntimeError, r'torch.save\(\) is not currently supported'):
            b = io.BytesIO()
            torch.save(qlinear, b)

        # Test JIT
        self.checkScriptable(qlinear,
                             list(zip([X], [Z_ref])),
                             check_save_load=True)

        # Test from_float
        float_linear = torch.nn.Linear(in_features, out_features).float()
        if use_default_observer:
            float_linear.qconfig = torch.quantization.default_dynamic_qconfig
        prepare_dynamic(float_linear)
        float_linear(X.float())
        quantized_float_linear = nnqd.Linear.from_float(float_linear)

        # Smoke test to make sure the module actually runs
        quantized_float_linear(X)

        # Smoke test extra_repr
        str(quantized_float_linear)
def qpt(t, scale, zero_point, dtype=torch.quint8):
    t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)
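
# Illustration of the qpt helper above with arbitrary qparams
# (scale=0.1, zero_point=0); the values are made up for the example.
q_example = qpt([0.0, 0.5, 1.0], 0.1, 0)
# q_example.int_repr() -> tensor([ 0,  5, 10], dtype=torch.uint8)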
    def test_conv_api(self, use_bias, use_fused):
        """Tests the correctness of the conv module.

        The correctness is defined against the functional implementation.
        """

        N, iC, H, W = 10, 10, 10, 3
        oC, g, kH, kW = 16, 1, 3, 3
        scale, zero_point = 1.0 / 255, 128

        X = torch.randn(N, iC, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(X,
                                       scale=scale,
                                       zero_point=128,
                                       dtype=torch.quint8)

        w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)

        qw = torch.quantize_per_tensor(w,
                                       scale=scale,
                                       zero_point=0,
                                       dtype=torch.qint8)

        b = torch.randn(oC, dtype=torch.float32) if use_bias else None

        if use_fused:
            conv_under_test = ConvReLU2d(in_channels=iC,
                                         out_channels=oC,
                                         kernel_size=(kH, kW),
                                         stride=1,
                                         padding=0,
                                         dilation=1,
                                         groups=g,
                                         bias=use_bias,
                                         padding_mode='zeros')
        else:
            conv_under_test = Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros')
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        conv_under_test.set_weight_bias(qw, b)
        conv_under_test(qX)

        conv_under_test.scale = scale
        conv_under_test.zero_point = zero_point

        # Test members
        self.assertTrue(hasattr(conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(conv_under_test, 'scale'))
        self.assertTrue(hasattr(conv_under_test, 'zero_point'))

        # Test properties
        self.assertEqual(qw, conv_under_test.weight())
        self.assertEqual(b, conv_under_test.bias())
        self.assertEqual(scale, conv_under_test.scale)
        self.assertEqual(zero_point, conv_under_test.zero_point)

        # Test forward
        result_under_test = conv_under_test(qX)
        result_reference = qF.conv2d(qX,
                                     qw,
                                     bias=b,
                                     scale=scale,
                                     zero_point=zero_point,
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     dtype=torch.quint8)
        if use_fused:
            # result_reference < zero_point doesn't work for qtensor yet
            # result_reference[result_reference < zero_point] = zero_point
            MB, OC, OH, OW = result_reference.size()
            for i in range(MB):
                for j in range(OC):
                    for h in range(OH):
                        for w in range(OW):
                            if result_reference[i][j][h][w].int_repr(
                            ) < zero_point:
                                # assign 0. that gets converted to zero_point
                                result_reference[i][j][h][w] = 0.

        self.assertEqual(result_reference,
                         result_under_test,
                         message="Tensors are not equal.")

        # Test serialization of quantized Conv Module using state_dict
        model_dict = conv_under_test.state_dict()
        self.assertEqual(model_dict['weight'], qw)
        if use_bias:
            self.assertEqual(model_dict['bias'], b)
        b = io.BytesIO()
        torch.save(model_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in model_dict:
            self.assertEqual(loaded_dict[key], model_dict[key])
        if use_fused:
            loaded_conv_under_test = ConvReLU2d(in_channels=iC,
                                                out_channels=oC,
                                                kernel_size=(kH, kW),
                                                stride=1,
                                                padding=0,
                                                dilation=1,
                                                groups=g,
                                                bias=use_bias,
                                                padding_mode='zeros')
        else:
            loaded_conv_under_test = Conv2d(in_channels=iC,
                                            out_channels=oC,
                                            kernel_size=(kH, kW),
                                            stride=1,
                                            padding=0,
                                            dilation=1,
                                            groups=g,
                                            bias=use_bias,
                                            padding_mode='zeros')
        loaded_conv_under_test.load_state_dict(loaded_dict)
        self.assertEqual(loaded_conv_under_test._weight_bias(),
                         conv_under_test._weight_bias())
        if use_bias:
            self.assertEqual(loaded_conv_under_test.bias(),
                             conv_under_test.bias())
        self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
        self.assertEqual(loaded_conv_under_test.zero_point,
                         conv_under_test.zero_point)
        self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
        self.assertTrue(hasattr(conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_packed_params'))
        self.assertTrue(hasattr(conv_under_test, '_weight_bias'))
        self.assertTrue(hasattr(loaded_conv_under_test, '_weight_bias'))
        self.assertEqual(loaded_conv_under_test._weight_bias(),
                         conv_under_test._weight_bias())
        self.assertEqual(loaded_conv_under_test.weight(), qw)
        loaded_result = loaded_conv_under_test(qX)
        self.assertEqual(loaded_result, result_reference)

        # The below check is meant to ensure that `torch.save` and `torch.load`
        # serialization works, however it is currently broken by the following:
        # https://github.com/pytorch/pytorch/issues/24045
        #
        # Instead, we currently check that the proper exception is thrown on save.
        # <start code>
        # b = io.BytesIO()
        # torch.save(conv_under_test, b)
        # b.seek(0)
        # loaded_conv = torch.load(b)
        #
        # self.assertEqual(conv_under_test.bias(), loaded_conv.bias())
        # self.assertEqual(conv_under_test.scale, loaded_conv.scale)
        # self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)
        # <end code>
        with self.assertRaisesRegex(
                RuntimeError, r'torch.save\(\) is not currently supported'):
            b = io.BytesIO()
            torch.save(conv_under_test, b)

        # JIT testing
        self.checkScriptable(conv_under_test,
                             list(zip([qX], [result_reference])),
                             check_save_load=True)

        # Test from_float
        float_conv = torch.nn.Conv2d(in_channels=iC,
                                     out_channels=oC,
                                     kernel_size=(kH, kW),
                                     stride=1,
                                     padding=0,
                                     dilation=1,
                                     groups=g,
                                     bias=use_bias,
                                     padding_mode='zeros').float()
        float_conv.qconfig = torch.quantization.default_qconfig
        torch.quantization.prepare(float_conv, inplace=True)
        float_conv(X.float())
        quantized_float_conv = torch.nn.Sequential(float_conv)
        torch.quantization.convert(quantized_float_conv, inplace=True)

        # Smoke test to make sure the module actually runs
        quantized_float_conv(qX)
        if use_bias:
            self.assertEqual(quantized_float_conv[0].bias(), float_conv.bias)
        # Smoke test extra_repr
        str(quantized_float_conv)
Exemple #26
0
    def test_single_tensors(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        get_expr: List[str] = []
        samples = [
            # Tensor with small numel and small storage.
            (
                torch.tensor([1]), ),
            # Tensor with large numel and small storage.
            (
                torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2], ),
            # Tensor with small numel and large storage.
            (
                torch.tensor(range(1 << 16))[-8:], ),
            # Large zero tensor.
            (
                torch.zeros(1 << 16), ),
            # Large channels-last ones tensor.
            (
                torch.ones(4, 8, 32,
                           32).contiguous(memory_format=torch.channels_last),
            ),
            # Special encoding of random tensor.
            (
                torch.utils.bundled_inputs.bundle_randn(1 << 16), ),
            # Quantized uniform tensor.
            (
                torch.quantize_per_tensor(torch.zeros(4, 8, 32, 32), 1, 0,
                                          torch.qint8), ),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr)
        # print(get_expr[0])
        # print(sm._generate_bundled_inputs.code)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))
        self.assertTrue(loaded.run_on_bundled_input(0) is inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            if idx != 5:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        self.assertEqual(inflated[5][0].shape, (1 << 16, ))
        self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0)
        self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0)
Exemple #27
0
import torch

x = torch.rand(2, 3, dtype=torch.float32)

# x
# tensor([[0.6839, 0.4741, 0.7451],
#         [0.9301, 0.1742, 0.6835]])

print(x)

xq = torch.quantize_per_tensor(x, scale=0.5, zero_point=8, dtype=torch.quint8)
# tensor([[0.5000, 0.5000, 0.5000],
#         [1.0000, 0.0000, 0.5000]], size=(2, 3), dtype=torch.quint8,
#        quantization_scheme=torch.per_tensor_affine, scale=0.5, zero_point=8)

print(xq)

print(xq.int_repr())
# tensor([[ 9,  9,  9],
#         [10,  8,  9]], dtype=torch.uint8)
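
# Round-trip sketch for the tensor above: dequantize() maps each element back
# as (int_repr - zero_point) * scale, so only the nearest representable values
# are recovered (matching the xq printout).
x_dq = xq.dequantize()
print(x_dq)
# tensor([[0.5000, 0.5000, 0.5000],
#         [1.0000, 0.0000, 0.5000]])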
Exemple #28
0
def pack_weights_for_functionals(
    module: torch.nn.Module,
) -> None:
    """
    Packs weights for functionals seen while tracing.
    Note: weight packing for modules is handled by eager mode quantization
    flow.
    """
    if hasattr(module, '_auto_quant_state'):
        qstate: AutoQuantizationState = module._auto_quant_state  # type: ignore[assignment]
        # find any ops which need packing
        for idx, seen_q_op_info in qstate.idx_to_seen_q_op_infos.items():
            packable_args_len = len(seen_q_op_info.packable_tensor_idx_to_name) + \
                len(seen_q_op_info.packable_nontensor_idx_to_arg)
            if packable_args_len == 0:
                continue

            if seen_q_op_info.type in conv_ops:
                # fetch all the info needed for packed params
                assert seen_q_op_info.packable_tensor_idx_to_name[1] is not None
                weight = getattr(module, seen_q_op_info.packable_tensor_idx_to_name[1])
                assert seen_q_op_info.packable_tensor_idx_to_name[2] is not None
                bias = getattr(module, seen_q_op_info.packable_tensor_idx_to_name[2])
                stride = seen_q_op_info.packable_nontensor_idx_to_arg[3]
                padding = seen_q_op_info.packable_nontensor_idx_to_arg[4]
                dilation = seen_q_op_info.packable_nontensor_idx_to_arg[5]
                groups = seen_q_op_info.packable_nontensor_idx_to_arg[6]

                # quantize the weight
                # TODO: create weight observers from qconfig.weight
                assert seen_q_op_info.input_tensor_infos[1] is not None
                weight_tensor_id = seen_q_op_info.input_tensor_infos[1].id
                weight_obs = qstate.tensor_id_to_observer[str(weight_tensor_id)]
                assert isinstance(weight_obs, (ObserverBase, FakeQuantizeBase))
                scale, zp = weight_obs.calculate_qparams()
                qweight = torch.quantize_per_tensor(weight, scale, zp, torch.qint8)

                # create the packed params
                packed_params = conv_prepack_fns[seen_q_op_info.type](
                    qweight, bias, stride, padding, dilation, groups)

                # attach to module
                name_idx = 0
                prefix = "_packed_params_"
                name_candidate = f"{prefix}{name_idx}"
                while hasattr(module, name_candidate):
                    name_idx += 1
                    name_candidate = f"{prefix}{name_idx}"
                setattr(module, name_candidate, packed_params)
                qstate.idx_to_packed_weight_name[idx] = name_candidate
                # TODO: delete the original weights

            elif seen_q_op_info.type == F.linear:
                # fetch all the info needed for packed params
                def get_tensor_param_name(idx: int, name: str) -> Optional[str]:
                    param_name = seen_q_op_info.packable_tensor_idx_to_name.get(idx, None)
                    if param_name is not None:
                        return param_name
                    return seen_q_op_info.packable_tensor_kwarg_name_to_name.get(name, None)

                weight_name = get_tensor_param_name(1, 'weight')
                assert weight_name is not None
                weight = getattr(module, weight_name)

                bias_name = get_tensor_param_name(2, 'bias')
                bias = getattr(module, bias_name) if bias_name is not None else None

                # quantize the weight
                # TODO: create weight observers from qconfig.weight
                assert seen_q_op_info.input_tensor_infos[1] is not None
                weight_tensor_id = seen_q_op_info.input_tensor_infos[1].id
                weight_obs = qstate.tensor_id_to_observer[str(weight_tensor_id)]
                assert isinstance(weight_obs, (ObserverBase, FakeQuantizeBase))
                scale, zp = weight_obs.calculate_qparams()
                qweight = torch.quantize_per_tensor(weight, scale, zp, torch.qint8)

                # create the packed params
                packed_params = toq.linear_prepack(qweight, bias)

                # attach to module
                name_idx = 0
                prefix = "_packed_params_"
                name_candidate = f"{prefix}{name_idx}"
                while hasattr(module, name_candidate):
                    name_idx += 1
                    name_candidate = f"{prefix}{name_idx}"
                setattr(module, name_candidate, packed_params)
                qstate.idx_to_packed_weight_name[idx] = name_candidate
                # TODO: delete the original weights

    for _, child in module.named_children():
        pack_weights_for_functionals(child)
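
# For illustration, the F.linear branch above reduces to roughly this sketch;
# the weight shape and qparams are made up, and only torch itself is assumed.
def _sketch_pack_linear_weight():
    weight = torch.randn(4, 8)
    bias = torch.zeros(4)
    qweight = torch.quantize_per_tensor(weight, scale=0.05, zero_point=0,
                                        dtype=torch.qint8)
    # Pack weight + bias into the object that quantized::linear consumes.
    packed = torch.ops.quantized.linear_prepack(qweight, bias)
    # The packed params can be unpacked again for inspection.
    w_back, b_back = torch.ops.quantized.linear_unpack(packed)
    return packed, w_back, b_back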
Exemple #29
0
def _make_conv_test_input(
    batch_size,
    in_channels_per_group,
    input_feature_map_size,
    out_channels_per_group,
    groups,
    kernel_size,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    use_bias,
    use_channelwise,
):
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups

    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(X_value_min, X_value_max, (
        batch_size,
        in_channels,
    ) + input_feature_map_size)
    X = X_scale * (X_init - X_zero_point).float()
    X_q = torch.quantize_per_tensor(X,
                                    scale=X_scale,
                                    zero_point=X_zero_point,
                                    dtype=torch.quint8)

    W_scale = W_scale * out_channels
    W_zero_point = W_zero_point * out_channels
    # Resize the W_scale and W_zero_point lists to have exactly out_channels entries
    W_scale = W_scale[:out_channels]
    W_zero_point = W_zero_point[:out_channels]
    # For testing, we use small values for weights and for activations so that
    # no overflow occurs in the vpmaddubsw instruction. If overflow occurs in
    # the qconv implementation but not in the reference, we cannot exactly
    # match the results with the reference.
    # Please see the comment in qconv implementation file
    #   aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
    (W_value_min, W_value_max) = (-5, 5)
    # The operator expects them in the format
    # (out_channels, in_channels/groups,) + kernel_size
    W_init = torch.randint(W_value_min, W_value_max, (
        out_channels,
        in_channels_per_group,
    ) + kernel_size)
    b_init = torch.randint(0, 10, (out_channels, ))

    if use_channelwise:
        W_shape = (-1, 1) + (1, ) * len(kernel_size)
        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
        W = W_scales_tensor.reshape(*W_shape) * (
            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
        b = X_scale * W_scales_tensor * b_init.float()
        W_q = torch.quantize_per_channel(W,
                                         W_scales_tensor,
                                         W_zero_points_tensor.long(),
                                         0,
                                         dtype=torch.qint8)
    else:
        W = W_scale[0] * (W_init - W_zero_point[0]).float()
        b = X_scale * W_scale[0] * b_init.float()
        W_q = torch.quantize_per_tensor(W,
                                        scale=W_scale[0],
                                        zero_point=W_zero_point[0],
                                        dtype=torch.qint8)

    return (X, X_q, W, W_q, b if use_bias else None)
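
# Hypothetical call of the helper above for a plain per-tensor 3x3 conv; every
# argument value is made up for illustration. W_scale / W_zero_point are lists
# because the helper tiles and then slices them per output channel.
(X_ex, X_q_ex, W_ex, W_q_ex, b_ex) = _make_conv_test_input(
    batch_size=2, in_channels_per_group=4, input_feature_map_size=(8, 8),
    out_channels_per_group=8, groups=1, kernel_size=(3, 3),
    X_scale=0.1, X_zero_point=0, W_scale=[0.5], W_zero_point=[0],
    use_bias=True, use_channelwise=False)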
Exemple #30
0
 def forward(self, X):
     return torch.quantize_per_tensor(X, float(self.scale),
                                      int(self.zero_point), self.dtype)
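
# The forward above looks like a quantization stub. A minimal, hypothetical
# wrapper module exercising it (a sketch, not the PyTorch implementation):
import torch

class _QuantizeStubSketch(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype=torch.quint8):
        super().__init__()
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, X):
        return torch.quantize_per_tensor(X, float(self.scale),
                                         int(self.zero_point), self.dtype)

# Example: _QuantizeStubSketch(scale=0.1, zero_point=0)(torch.randn(2, 3))
# returns a quantized quint8 tensor with the given qparams.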