Esempio n. 1
0
    def test_max(self):
        """Compare ``calibrate_weights(method="max")`` with per-quantizer calibration.

        A reference network has every QuantConv2d / QuantLinear weight
        quantizer calibrated by hand; a second network is calibrated through
        ``calib.calibrate_weights``. The resulting ``amax`` tensors are
        compared with all tolerances set to zero, and their shapes must match.
        """
        quant_layer_types = (quant_nn.QuantConv2d, quant_nn.QuantLinear)

        # Re-seeding before each construction presumably gives both networks
        # identical initial weights -- depends on QuantLeNet's initialisation.
        torch.manual_seed(12345)
        ref_lenet = QuantLeNet()
        torch.manual_seed(12345)
        test_lenet = QuantLeNet()

        # Reference path: enable calibration, disable quantization, pass the
        # layer's weight through the quantizer, then load the calibrated amax.
        for layer in ref_lenet.modules():
            if not isinstance(layer, quant_layer_types):
                continue
            quantizer = layer.weight_quantizer
            quantizer.enable_calib()
            quantizer.disable_quant()
            quantizer(layer.weight)
            quantizer.load_calib_amax()

        # Path under test.
        calib.calibrate_weights(test_lenet, method="max")

        for ref_layer, test_layer in zip(ref_lenet.modules(),
                                         test_lenet.modules()):
            if not isinstance(ref_layer, quant_layer_types):
                continue
            ref_amax = ref_layer.weight_quantizer.amax
            test_amax = test_layer.weight_quantizer.amax
            test_utils.compare(ref_amax, test_amax, rtol=0, atol=0, ctol=0)
            assert ref_amax.shape == test_amax.shape
Esempio n. 2
0
    def test_mse_with_axis(self):
        """Check one slice of the per-channel "mse" amax computed for conv2.

        ``calibrate_weights`` is run with ``method="mse"`` and
        ``perchannel=True``. Index 1 of the conv2 weight is then fed to a
        standalone ``HistogramCalibrator`` and its "mse" amax is compared,
        with all tolerances set to zero, to index 1 of the amax stored on the
        conv2 weight quantizer.
        """
        channel = 1

        torch.manual_seed(12345)
        test_lenet = QuantLeNet()

        ref_calibrator = calib.HistogramCalibrator(8, None, False)

        calib.calibrate_weights(test_lenet, method="mse", perchannel=True)

        # Build the expected value from the same weight slice.
        conv = test_lenet.conv2
        ref_calibrator.collect(conv.weight[channel])
        expected_amax = ref_calibrator.compute_amax("mse")

        actual_amax = conv.weight_quantizer.amax[channel]
        test_utils.compare(expected_amax, actual_amax, rtol=0, atol=0, ctol=0)
Esempio n. 3
0
    def test_percentile(self):
        """Check the "percentile" amax computed for conv1 without per-channel mode.

        ``calibrate_weights`` is run with ``method="percentile"`` and
        ``perchannel=False``. The whole conv1 weight is then fed to a
        standalone ``HistogramCalibrator`` using the same percentile, and the
        two amax values are compared with all tolerances set to zero.
        """
        pct = 99.99

        torch.manual_seed(12345)
        test_lenet = QuantLeNet()

        ref_calibrator = calib.HistogramCalibrator(8, None, False)

        calib.calibrate_weights(
            test_lenet,
            method="percentile",
            perchannel=False,
            percentile=pct,
        )

        # Build the expected value from the full conv1 weight tensor.
        conv = test_lenet.conv1
        ref_calibrator.collect(conv.weight)
        expected_amax = ref_calibrator.compute_amax("percentile", percentile=pct)

        actual_amax = conv.weight_quantizer.amax
        test_utils.compare(expected_amax, actual_amax, rtol=0, atol=0, ctol=0)
Esempio n. 4
0
    def test_shape_with_axis(self):
        """Check calibrate_weights yields amax of the same shape as TensorQuantizer.

        A reference network has each QuantConv2d / QuantLinear weight quantizer
        calibrated by hand; a second network goes through
        ``calib.calibrate_weights(method="percentile")``. Only the ``amax``
        shapes are compared, not the values.
        """
        quant_types = (quant_nn.QuantConv2d, quant_nn.QuantLinear)

        # Re-seeding before each construction presumably gives both networks
        # identical initial weights -- depends on QuantLeNet's initialisation.
        torch.manual_seed(12345)
        ref_lenet = QuantLeNet()
        torch.manual_seed(12345)
        test_lenet = QuantLeNet()

        # Reference path: run every quantized layer's own weight quantizer in
        # calibration mode (quantization disabled) and load the amax it found.
        ref_layers = [m for m in ref_lenet.modules() if isinstance(m, quant_types)]
        for layer in ref_layers:
            weight_quantizer = layer.weight_quantizer
            weight_quantizer.enable_calib()
            weight_quantizer.disable_quant()
            weight_quantizer(layer.weight)
            weight_quantizer.load_calib_amax()

        # Path under test.
        calib.calibrate_weights(test_lenet, method="percentile")

        for ref_module, test_module in zip(ref_lenet.modules(), test_lenet.modules()):
            if isinstance(ref_module, quant_types):
                ref_shape = ref_module.weight_quantizer.amax.shape
                test_shape = test_module.weight_quantizer.amax.shape
                assert ref_shape == test_shape