import torch
from torch.quantization import (QConfig, FakeQuantize,
                                MovingAverageMinMaxObserver,
                                default_weight_fake_quant)
from torch.quantization.quantize import (_propagate_qconfig_helper,
                                         add_observer_, convert)

# QAT_QCONFIG_PROPAGATE_WHITE_LIST and QAT_QUANTIZED_MODULE_MAPPING are
# project-specific constants assumed to be defined alongside this helper.


def prepare_for_qat(model, quantize_weights_per_channel, fuse_relu):
    """Prepares model for quantization aware training"""

    # fuse modules
    model.fuse_model(fuse_relu=fuse_relu)

    # set qconfig
    if quantize_weights_per_channel:
        qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    else:
        print("Quantizating weights per tensor")
        qconfig = QConfig(
            activation=FakeQuantize.with_args(
                observer=MovingAverageMinMaxObserver,
                quant_min=0,
                quant_max=255,
                reduce_range=True),
            weight=default_weight_fake_quant)
    model.qconfig = qconfig

    # equivalent to torch.quantization.prepare(inplace=True); required for a custom white list
    # propagate qconfig and add observers
    _propagate_qconfig_helper(model,
                              qconfig_dict={},
                              white_list=QAT_QCONFIG_PROPAGATE_WHITE_LIST)
    if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
        print("None of the submodule got qconfig applied. Make sure you "
              "passed correct configuration through `qconfig_dict` or "
              "by assigning the `.qconfig` attribute directly on submodules")
    add_observer_(model)

    # convert modules to their QAT versions; move the model to the device afterwards
    convert(model, QAT_QUANTIZED_MODULE_MAPPING, inplace=True)
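
# A minimal usage sketch (MyQuantizableModel is hypothetical; it must
# implement fuse_model(fuse_relu=...) the way prepare_for_qat expects):
#
#     model = MyQuantizableModel().train()
#     prepare_for_qat(model, quantize_weights_per_channel=True, fuse_relu=True)
#     model.to(device)   # move to the device after the QAT conversion
#     ...                # run the QAT training loop
#     quantized = torch.quantization.convert(model.eval())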
Example #2
    def test_fq_module(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer,
                                 quant_min,
                                 quant_max,
                                 ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(
            X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(),
                                   Y_prime.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, fq_module.scale, fq_module.zero_point, axis, quant_min,
            quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(),
                                   X.grad.cpu().detach().numpy(),
                                   rtol=tolerance,
                                   atol=tolerance)
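
# For orientation, a minimal sketch of the math the reference functions above
# compute (per-tensor form; the per-channel variant broadcasts scale and
# zero_point along ch_axis). The function name is illustrative:
def _fake_quantize_sketch(X, scale, zero_point, quant_min, quant_max):
    q = torch.clamp(torch.round(X / scale + zero_point), quant_min, quant_max)
    return (q - zero_point) * scale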
Example #3
    def test_fq_serializable(self):
        observer = default_per_channel_weight_observer
        quant_min = -128
        quant_max = 127
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([[-5, -3.5, -2, 0, 3, 5, 7],
                          [1, 3, 2, 5, 6.5, 8, 10]], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
        self.assertEqual(state_dict['zero_point'], [0, 0])
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
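
# Where the expected values come from (per-channel symmetric qint8; a sketch):
#   scale = max(|min_val|, |max_val|) / ((quant_max - quant_min) / 2)
#     row 0: 7.0 / 127.5 ≈ 0.054902    row 1: 10.0 / 127.5 ≈ 0.078431
#   zero_point = 0 for symmetric quantization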
Example #4
    def test_fake_quant_scriptable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        scripted_module = torch.jit.script(fq_module)

        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)

        fq_module(X)
        scripted_module(X)
        self.assertEqual(fq_module.calculate_qparams(),
                         scripted_module.calculate_qparams())

        buf = io.BytesIO()
        torch.jit.save(scripted_module, buf)
        buf.seek(0)
        loaded_module = torch.jit.load(buf)
        self.assertEqual(fq_module.calculate_qparams(),
                         loaded_module.calculate_qparams())
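
# The same round trip works with a file path instead of an in-memory buffer:
#   torch.jit.save(scripted_module, "fq_module.pt")
#   loaded_module = torch.jit.load("fq_module.pt")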
Example #5
    def test_fq_serializable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], 0.094488)
        self.assertEqual(state_dict['zero_point'], 53)
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
        loaded_fq_module.load_state_dict(loaded_dict)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_fq_module.state_dict()[key])

        self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())
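
# Where the expected qparams come from (affine quint8; the default observer
# reduces the range, so the effective quant range is 0..127; a sketch):
#   scale = (max_val - min_val) / 127 = (7.0 - (-5.0)) / 127 ≈ 0.094488
#   zero_point = quant_min - round(min_val / scale) = round(5.0 / 0.094488) = 53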
Example #6
    def test_compare_fused_obs_fq_oss_module(self, device):
        mod = FusedMovingAvgObsFakeQuantize()
        torch.quantization.enable_fake_quant(mod)
        torch.quantization.enable_observer(mod)
        mod.to(device)

        mod_ref = FakeQuantize()
        torch.quantization.enable_fake_quant(mod_ref)
        torch.quantization.enable_observer(mod_ref)
        mod_ref.to(device)

        for i in range(10):
            x = torch.randn(5, 5, device=device)
            out = mod(x)
            out_ref = mod_ref(x)
            torch.testing.assert_allclose(out, out_ref)
            torch.testing.assert_allclose(
                mod_ref.activation_post_process.min_val,
                mod.activation_post_process.min_val,
            )
            torch.testing.assert_allclose(
                mod_ref.activation_post_process.max_val,
                mod.activation_post_process.max_val,
            )
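
# In a typical QAT recipe the observers are frozen partway through training
# while fake-quant stays enabled, e.g.:
#   torch.quantization.disable_observer(mod)   # qparams stop updating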
Example #7
_TFLITE_QCONFIG = QConfig(  # name assumed; the original snippet begins mid-assignment
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    ),
    weight=MinMaxObserver.with_args(dtype=torch.qint8,
                                    quant_min=-127,
                                    quant_max=127,
                                    qscheme=torch.per_tensor_symmetric),
)
_TFLITE_QAT_QCONFIG = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        dtype=torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    ),
    weight=FakeQuantize.with_args(observer=MinMaxObserver,
                                  dtype=torch.qint8,
                                  quant_min=-127,
                                  quant_max=127,
                                  qscheme=torch.per_tensor_symmetric),
)
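
# A minimal usage sketch for these qconfigs (model is illustrative):
#   model.qconfig = _TFLITE_QAT_QCONFIG
#   torch.quantization.prepare_qat(model, inplace=True)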
_ONNX_QCONFIG = QConfig(
    activation=MinMaxObserver.with_args(
        quant_min=0,
        quant_max=255,
        reduce_range=True,
    ),