示例#1
0
    def test_simple_run(self):
        """Per-tensor MaxCalibrator: amax equals the max over every collected tensor."""
        calibrator = calib.MaxCalibrator(num_bits=8, axis=None, unsigned=False)

        first = torch.rand(129).cuda()
        second = torch.rand(127).cuda()
        for batch in (first, second):
            calibrator.collect(batch)

        expected_amax = torch.max(first.max(), second.max())
        test_utils.compare(calibrator.compute_amax(), expected_amax, atol=0, rtol=0, ctol=0)

        # The unsigned variant has nothing to check beyond successful construction.
        calib.MaxCalibrator(num_bits=8, axis=None, unsigned=True)
示例#2
0
    def test_raises(self):
        """Per-channel collection must reject a batch whose shape differs from the first one."""
        calibrator = calib.MaxCalibrator(num_bits=8, axis=0, unsigned=False)

        reference = torch.rand(32, 63, 7, 7).cuda()
        mismatched = torch.rand(33, 63, 7, 7).cuda()  # dim 0 differs from `reference`

        calibrator.collect(reference)
        with pytest.raises(RuntimeError, match="shape changed"):
            calibrator.collect(mismatched)
示例#3
0
    def test_track_amax(self):
        """With track_amax=True the calibrator also records the amax of each collected tensor."""
        calibrator = calib.MaxCalibrator(num_bits=8, axis=None, unsigned=False, track_amax=True)

        inputs = [torch.rand(129).cuda(), torch.rand(127).cuda()]
        for tensor in inputs:
            calibrator.collect(tensor)

        test_utils.compare(calibrator.compute_amax(),
                           torch.max(inputs[0].max(), inputs[1].max()),
                           atol=0, rtol=0, ctol=0)
        # Expect one recorded amax per collect() call, in collection order.
        for idx, tensor in enumerate(inputs):
            np.testing.assert_array_equal(calibrator.amaxs[idx], tensor.max().cpu().numpy())
示例#4
0
    def __init__(self,
                 quant_desc=QuantDescriptor(),
                 disabled=False,
                 if_quant=True,
                 if_clip=False,
                 if_calib=False):
        """Initialize quantizer and set up required variables.

        Args:
            quant_desc: Descriptor whose fields (num_bits, fake_quant, axis, scale_amax,
                learn_amax, unsigned, narrow_range, amax, calib_method) configure this
                quantizer. The default instance is created once at definition time and
                shared between calls; this constructor only reads from it.
            disabled: Initial value of the ``_disabled`` flag.
            if_quant: Initial value of the ``_if_quant`` flag.
            if_clip: If True, turn on the clip stage via ``enable_clip()``.
            if_calib: Initial value of the ``_if_calib`` flag.

        Raises:
            ValueError: If ``quant_desc.calib_method`` is neither "histogram" nor "max".
        """
        super(TensorQuantizer, self).__init__()
        # Expand quant_desc. Using quant_desc.dict would be easier, but copying the
        # fields one-by-one explicitly gives more control.
        self._num_bits = quant_desc.num_bits
        self._fake_quant = quant_desc.fake_quant
        self._axis = quant_desc.axis
        self._scale_amax = quant_desc.scale_amax
        self._learn_amax = quant_desc.learn_amax
        self._unsigned = quant_desc.unsigned
        self._narrow_range = quant_desc.narrow_range

        self._scale = None if not quant_desc.fake_quant else 1.
        self._disabled = disabled
        self._if_quant = if_quant
        # Start with clip off; it is switched on below through enable_clip() when requested.
        self._if_clip = False
        self._if_calib = if_calib

        if quant_desc.amax is not None:
            self.register_buffer('_amax', torch.tensor(quant_desc.amax))

        # Clip module consumes a lot of memory, so only create it if learn_amax is True
        if self._learn_amax:
            init_amax = quant_desc.amax if quant_desc.amax is not None else 1.
            self.clip = Clip(-init_amax,
                             init_amax,
                             learn_min=True,
                             learn_max=True)

        # The clip stage is what learns amax, so learn_amax implies the clip stage is
        # enabled. Call enable_clip() only once, even when if_clip is also requested.
        if self._learn_amax or if_clip:
            self.enable_clip()

        calib_method = quant_desc.calib_method
        if calib_method == "histogram":
            logging.info("Creating histogram calibrator")
            self._calibrator = calib.HistogramCalibrator(
                num_bits=self._num_bits,
                axis=self._axis,
                unsigned=self._unsigned)
        elif calib_method == "max":
            logging.info("Creating Max calibrator")
            self._calibrator = calib.MaxCalibrator(num_bits=self._num_bits,
                                                   axis=self._axis,
                                                   unsigned=self._unsigned)
        else:
            # Fail fast: otherwise _calibrator would be left unset and only surface later
            # as an AttributeError far from the misconfiguration.
            raise ValueError(
                "Unknown calib_method {!r}; expected 'histogram' or 'max'.".format(calib_method))
示例#5
0
    def test_fine_grain(self):
        """Per-channel (axis=0) MaxCalibrator keeps one amax per slice along dim 0."""
        channel_axis = 0
        reduce_axis = (1, 2, 3)  # every dim except channel_axis
        calibrator = calib.MaxCalibrator(num_bits=8, axis=channel_axis, unsigned=False)

        first = torch.rand(31, 63, 7, 7).cuda()
        second = torch.rand(31, 63, 7, 7).cuda()
        for batch in (first, second):
            calibrator.collect(batch)

        assert calibrator.compute_amax().shape[0] == 31

        # Reference: reduce the elementwise max of both batches over the non-channel dims.
        expected_amax = quant_utils.reduce_amax(torch.max(first, second), axis=reduce_axis)
        test_utils.compare(calibrator.compute_amax(), expected_amax, atol=0, rtol=0, ctol=0)

        # After reset() no statistics remain, so compute_amax() yields None.
        calibrator.reset()
        assert calibrator.compute_amax() is None
示例#6
0
 def test_print_calibrator(self):
     """Printing calibrators, before and after collecting data, must not raise."""
     max_cal = calib.MaxCalibrator(num_bits=7, axis=1, unsigned=False)
     print(max_cal)

     hist_cal = calib.HistogramCalibrator(num_bits=8, axis=None, unsigned=True)
     hist_cal.collect(torch.rand(10))
     print(hist_cal)
示例#7
0
 def test_repr(self):
     """repr() of a MaxCalibrator with track_amax enabled must succeed."""
     calibrator = calib.MaxCalibrator(num_bits=8, axis=None, unsigned=False, track_amax=True)
     text = repr(calibrator)
     assert isinstance(text, str)