コード例 #1
0
    def test_observer_scriptable(self):
        """Check that scripted observers agree with their eager counterparts.

        Each observer is compiled with ``torch.jit.script``; the eager and
        scripted copies are then fed the same random input and must report
        equal qparams. The scripted copy is also saved to an in-memory
        buffer, loaded back, and compared against the eager observer again.
        The tensor-list observer gets the eager-vs-scripted comparison only.
        """
        observers = [
            observer_cls()
            for observer_cls in (
                MinMaxObserver,
                MovingAverageMinMaxObserver,
                MinMaxDynamicQuantObserver,
            )
        ]
        for eager_obs in observers:
            scripted_obs = torch.jit.script(eager_obs)

            sample = torch.rand(3, 4)
            eager_obs(sample)
            scripted_obs(sample)
            eager_qparams = eager_obs.calculate_qparams()
            scripted_qparams = scripted_obs.calculate_qparams()
            self.assertEqual(eager_qparams, scripted_qparams)

            # Round-trip the scripted observer through an in-memory buffer.
            stream = io.BytesIO()
            torch.jit.save(scripted_obs, stream)
            stream.seek(0)  # rewind so the load starts at the first byte
            reloaded_obs = torch.jit.load(stream)
            eager_qparams = eager_obs.calculate_qparams()
            reloaded_qparams = reloaded_obs.calculate_qparams()
            self.assertEqual(eager_qparams, reloaded_qparams)

        # Check TensorListObserver
        from torch.quantization.observer import _MinMaxTensorListObserver
        list_obs = _MinMaxTensorListObserver()
        scripted_list_obs = torch.jit.script(list_obs)
        tensors = [torch.rand(3, 4), torch.rand(4, 5)]
        list_obs(tensors)
        scripted_list_obs(tensors)
        self.assertEqual(
            list_obs.calculate_qparams(),
            scripted_list_obs.calculate_qparams(),
        )
コード例 #2
0
 def test_tensor_list_observer(self):
     """Compare the tensor-list observer with per-tensor ``MinMaxObserver``s.

     A single ``_MinMaxTensorListObserver`` observes a list of three
     tensors, while a separate ``MinMaxObserver`` observes each tensor on
     its own. For every list position, the list observer's ``min_val``,
     ``max_val`` and both entries of its ``calculate_qparams()`` result
     must equal the values from the matching reference observer.
     """
     from torch.quantization.observer import _MinMaxTensorListObserver
     tensors = [
         torch.tensor([1.0, 2.5, 3.5]),
         torch.tensor([2.0, 4.5, 3.5]),
         torch.tensor([4.0, 2.5, 3.5]),
     ]
     list_obs = _MinMaxTensorListObserver()
     list_obs(tensors)
     list_qparams = list_obs.calculate_qparams()

     # One independent reference observer per tensor; keep its results
     # together as (min_val, max_val, qparams).
     references = []
     for tensor in tensors:
         single_obs = MinMaxObserver()
         single_obs(tensor)
         references.append((
             single_obs.min_val,
             single_obs.max_val,
             single_obs.calculate_qparams(),
         ))

     for idx, (ref_min, ref_max, ref_qparams) in enumerate(references):
         self.assertEqual(list_obs.min_val[idx], ref_min)
         self.assertEqual(list_obs.max_val[idx], ref_max)
         self.assertEqual(list_qparams[0][idx], ref_qparams[0])
         self.assertEqual(list_qparams[1][idx], ref_qparams[1])