def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        """Check HistogramObserver statistics, qparams and state_dict round-trip.

        Two small tensors covering [2, 8] are observed with ``bins=3``. The test
        then checks four things:

        * the recorded min/max/histogram;
        * the computed (scale, zero_point) against hard-coded reference values
          for the given dtype / qscheme / reduce_range combination;
        * that the ``state_dict`` survives a ``torch.save``/``torch.load`` round trip;
        * that a fresh observer loaded from that state reproduces the same
          statistics and qparams.

        Args:
            qdtype: quantized dtype under test; the reference values branch on
                ``qdtype is torch.qint8`` (presumably the other case is
                ``torch.quint8`` -- confirm against the test's parametrization).
            qscheme: ``torch.per_tensor_symmetric`` or another (affine) scheme.
            reduce_range: forwarded to the observer; changes the expected qparams.
        """
        myobs = HistogramObserver(bins=3,
                                  dtype=qdtype,
                                  qscheme=qscheme,
                                  reduce_range=reduce_range)
        # calculate_qparams() must work on an observer that has not seen any
        # data yet; only the absence of an error is checked here.
        myobs.calculate_qparams()
        x = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        out_x = myobs(x)
        # Observing must not break autograd: the output still requires grad.
        self.assertTrue(out_x.requires_grad)
        myobs(y)
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

        qparams = myobs.calculate_qparams()

        # Hard-coded reference qparams. With reduce_range the expected scale is
        # the full-range scale multiplied by 255 / 127.
        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294
                ref_zero_point = -128 if qdtype is torch.qint8 else 0

        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)

        # Serializability: the state_dict must survive a save/load round trip
        # through an in-memory buffer unchanged.
        state_dict = myobs.state_dict()
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])

        # A freshly constructed observer populated from the loaded state must
        # reproduce the original statistics and qparams.
        loaded_obs = HistogramObserver(bins=3,
                                       dtype=qdtype,
                                       qscheme=qscheme,
                                       reduce_range=reduce_range)
        loaded_obs.load_state_dict(loaded_dict)
        loaded_qparams = loaded_obs.calculate_qparams()
        self.assertEqual(myobs.min_val, loaded_obs.min_val)
        self.assertEqual(myobs.max_val, loaded_obs.max_val)
        self.assertEqual(myobs.histogram, loaded_obs.histogram)
        self.assertEqual(myobs.bins, loaded_obs.bins)
        self.assertEqual(myobs.calculate_qparams(), loaded_qparams)
Example #2
0
    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        """Observe two tensors and compare HistogramObserver qparams to references.

        The inputs cover [2, 8]; with three bins the expected histogram counts
        are [2, 3, 3]. The expected (scale, zero_point) pair depends on the
        dtype, the qscheme and whether reduce_range is enabled.
        """
        myobs = HistogramObserver(
            bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range
        )
        batches = (
            torch.tensor([2.0, 3.0, 4.0, 5.0]),
            torch.tensor([5.0, 6.0, 7.0, 8.0]),
        )
        for batch in batches:
            myobs(batch)

        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

        qparams = myobs.calculate_qparams()

        is_symmetric = qscheme == torch.per_tensor_symmetric
        is_qint8 = qdtype is torch.qint8

        # Full-range reference values first.
        if is_symmetric:
            ref_scale = 0.0470588
            ref_zero_point = 0 if is_qint8 else 128
        else:
            ref_scale = 0.0235294
            ref_zero_point = -128 if is_qint8 else 0

        # With reduce_range the expected scale grows by 255 / 127 (presumably
        # 127 instead of 255 quantization steps); only the affine zero point moves.
        if reduce_range:
            ref_scale = ref_scale * 255 / 127
            if not is_symmetric:
                ref_zero_point = -64 if is_qint8 else 0

        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
Example #3
0
 def test_histogram_observer_one_sided(self):
     """Non-negative inputs to an affine quint8 observer must yield zero_point 0.

     Every observed value is >= 0, so the recorded minimum is expected to be
     exactly 0 and the computed zero point to sit at the bottom of the range.
     """
     observer = HistogramObserver(
         bins=8,
         dtype=torch.quint8,
         qscheme=torch.per_tensor_affine,
         reduce_range=True,
     )
     batches = [
         torch.tensor([0.0, 0.3, 1.2, 1.7]),
         torch.tensor([0.1, 1.3, 2.0, 2.7]),
     ]
     for batch in batches:
         observer(batch)
     self.assertEqual(observer.min_val, 0)
     zero_point = observer.calculate_qparams()[1]
     self.assertEqual(zero_point.item(), 0)
Example #4
0
    def test_histogram_observer_against_reference(self, N, bins, dtype, qscheme, reduce_range):
        """HistogramObserver must match the reference implementation's qparams.

        Both observers are built with identical settings, fed the same ten
        random batches, and their calculate_qparams() results must be equal.

        Args:
            N: number of elements in each random batch.
            bins: histogram bin count given to both observers.
            dtype, qscheme, reduce_range: forwarded unchanged to both observers.
        """
        observer_kwargs = dict(
            bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range
        )
        ref_obs = _ReferenceHistogramObserver(**observer_kwargs)
        my_obs = HistogramObserver(**observer_kwargs)

        # Each step draws one batch and shows the very same tensor to both observers.
        for _ in range(10):
            batch = torch.randn(N)
            my_obs(batch)
            ref_obs(batch)

        self.assertEqual(ref_obs.calculate_qparams(), my_obs.calculate_qparams())
Example #5
0
 def test_histogram_observer_same_inputs(self):
     """Check HistogramObserver after constant (single-valued) inputs.

     An all-ones tensor and an all-zeros tensor (twice) are observed before
     two tensors covering [2, 8]. The assertions expect the final min/max and
     histogram to reflect only those last two tensors: the counts [2, 3, 3]
     total 8 elements. This presumably happens because the observer replaces
     a degenerate single-valued range; confirm against the observer's forward().
     """
     myobs = HistogramObserver(bins=3, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
     w = torch.ones(4, requires_grad=True)
     x = torch.zeros(4, requires_grad=True)
     y = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
     z = torch.tensor([5.0, 6.0, 7.0, 8.0])
     myobs(w)
     myobs(x)
     myobs(x)  # the same constant tensor is observed twice in a row
     myobs(y)
     myobs(z)
     # Only checks that qparams can be computed after these inputs; the
     # returned values are not asserted.
     myobs.calculate_qparams()
     self.assertEqual(myobs.min_val, 2.0)
     self.assertEqual(myobs.max_val, 8.0)
     self.assertEqual(myobs.histogram, [2., 3., 3.])