예제 #1
0
    def test_per_tensor_dynamic_quant_observers(self, X, reduce_range):
        """MinMaxDynamicQuantObserver qparams must match the per-tensor reference.

        Args:
            X: pair ``(array, (scale, zero_point, dtype))`` from the test's
                input generator (presumably a hypothesis strategy -- confirm
                against the decorator). Only the array is used here.
            reduce_range: forwarded unchanged to both the observer and
                ``torch._choose_qparams_per_tensor``.
        """
        # The generator's suggested (scale, zero_point, dtype) triple is not
        # needed; this test derives qparams from the data itself.
        X_np, _ = X
        x = torch.from_numpy(X_np)

        obs = MinMaxDynamicQuantObserver(dtype=torch.quint8, reduce_range=reduce_range)

        # Feed x through the observer before asking for qparams; the forward
        # return value itself is not needed.
        obs(x)
        qparams = obs.calculate_qparams()
        ref = torch._choose_qparams_per_tensor(x, reduce_range)

        # Compare element-wise: index 0 is the scale, index 1 the zero point.
        self.assertEqual(ref[0], qparams[0])
        self.assertEqual(ref[1], qparams[1])
예제 #2
0
    def test_observer_scriptable(self):
        """Scripted observers must agree with their eager counterparts.

        For each single-tensor observer, the qparams from the eager module,
        the ``torch.jit.script``-ed module, and the scripted module reloaded
        via ``torch.jit.save``/``torch.jit.load`` are compared. The
        tensor-list observer is then checked for eager vs. scripted agreement
        only (no save/load round trip).
        """
        observers = [
            MinMaxObserver(),
            MovingAverageMinMaxObserver(),
            MinMaxDynamicQuantObserver(),
        ]
        for eager in observers:
            script_mod = torch.jit.script(eager)

            # Feed the same random input to both variants, eager first.
            sample = torch.rand(3, 4)
            for module in (eager, script_mod):
                module(sample)
            self.assertEqual(eager.calculate_qparams(),
                             script_mod.calculate_qparams())

            # Serialize into an in-memory buffer, rewind, and load it back.
            stream = io.BytesIO()
            torch.jit.save(script_mod, stream)
            stream.seek(0)
            reloaded = torch.jit.load(stream)
            self.assertEqual(eager.calculate_qparams(),
                             reloaded.calculate_qparams())

        # The tensor-list observer is called with a list of tensors of
        # differing shapes instead of a single tensor.
        from torch.quantization.observer import _MinMaxTensorListObserver
        list_eager = _MinMaxTensorListObserver()
        list_scripted = torch.jit.script(list_eager)
        batch = [torch.rand(3, 4), torch.rand(4, 5)]
        for module in (list_eager, list_scripted):
            module(batch)
        self.assertEqual(list_eager.calculate_qparams(),
                         list_scripted.calculate_qparams())
예제 #3
0
    def test_observer_scriptable(self):
        """Check eager, scripted, and reloaded observers report equal qparams."""

        def _reload(scripted_module):
            # Round-trip a scripted module through an in-memory byte buffer.
            buffer = io.BytesIO()
            torch.jit.save(scripted_module, buffer)
            buffer.seek(0)
            return torch.jit.load(buffer)

        candidates = (MinMaxObserver(), MovingAverageMinMaxObserver(), MinMaxDynamicQuantObserver())
        for eager_obs in candidates:
            scripted_obs = torch.jit.script(eager_obs)

            # Same input to both; the eager module is always called first.
            data = torch.rand(3, 4)
            eager_obs(data)
            scripted_obs(data)
            self.assertEqual(eager_obs.calculate_qparams(), scripted_obs.calculate_qparams())

            restored = _reload(scripted_obs)
            self.assertEqual(eager_obs.calculate_qparams(), restored.calculate_qparams())