Example #1
import random

import torch
import torch.nn.quantized as nnq
import torch.quantization as tq
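
# NOTE (context, assumed): `init` below reads like a method of an
# operator_benchmark TorchBenchmarkBase subclass, and `test_relu` like a
# method of a unittest-style quantization TestCase; the class bodies are
# omitted in this example.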
    def init(self, dims, permute_dims, inplace, dtype):
        self.qop = nnq.ReLU(inplace=inplace)

        # Random float input spanning a wide range
        f_input = (torch.rand(*dims) - 0.5) * 1e6

        # Get quantization parameters and quantize
        if dtype in (torch.qint8, torch.quint8):
            observer = tq.MinMaxObserver(dtype=dtype,
                                         qscheme=torch.per_tensor_affine,
                                         reduce_range=False)
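            # Run the observer over the data so it records min/max, from
            # which (scale, zero_point) are derived below.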
            observer.forward(f_input)
            scale, zero_point = observer.calculate_qparams()
            scale, zero_point = scale.item(), zero_point.item()
        else:
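            # Other dtypes (e.g. torch.qint32): derive an affine scale by
            # hand from the observed range; zero_point stays at 0.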
            zero_point = 0
            qinfo = torch.iinfo(dtype)
            fmin, fmax = f_input.min().item(), f_input.max().item()
            if fmax == fmin:
                scale = 1.0
            else:
                scale = (fmax - fmin) / (qinfo.max - qinfo.min)

        # Quantize the tensor
        self.q_input = torch.quantize_per_tensor(f_input, scale=scale,
                                                 zero_point=zero_point,
                                                 dtype=dtype)
        if permute_dims:
            # Make non-contiguous
            new_shape = list(range(len(self.q_input.shape)))
            random.shuffle(new_shape)
            self.q_input = self.q_input.permute(new_shape)

        self.set_module_name("QReLU")
    def test_relu(self):
        relu_module = nnq.ReLU()
        relu6_module = nnq.ReLU6()

        x = torch.arange(-10, 10, dtype=torch.float)
        y_ref = torch.relu(x)
        y6_ref = torch.nn.ReLU6()(x)

        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.qint32)
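        # scale=1.0 and zero_point=0 keep the integer inputs exactly
        # representable in qint32, so the dequantized outputs can match the
        # float references exactly.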
        qy = relu_module(qx)
        qy6 = relu6_module(qx)

        self.assertEqual(y_ref, qy.dequantize(),
                         msg="ReLU module API failed")
        self.assertEqual(y6_ref, qy6.dequantize(),
                         msg="ReLU6 module API failed")