def prepare_for_qat(model, quantize_weights_per_channel, fuse_relu):
    """Prepares the model for quantization-aware training."""
    # Fuse modules (e.g. conv + bn [+ relu]) before inserting observers.
    model.fuse_model(fuse_relu=fuse_relu)

    # Set qconfig.
    if quantize_weights_per_channel:
        qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    else:
        print("Quantizing weights per tensor")
        qconfig = QConfig(
            activation=FakeQuantize.with_args(
                observer=MovingAverageMinMaxObserver,
                quant_min=0,
                quant_max=255,
                reduce_range=True,
            ),
            weight=default_weight_fake_quant,
        )
    model.qconfig = qconfig

    # Equivalent to quantize.prepare(..., inplace=True), but needed here to
    # pass a custom white list.
    # Propagate the qconfig and add observers.
    _propagate_qconfig_helper(
        model, qconfig_dict={}, white_list=QAT_QCONFIG_PROPAGATE_WHITE_LIST
    )
    if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
        print(
            "None of the submodules got a qconfig applied. Make sure you "
            "passed the correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly on submodules"
        )
    add_observer_(model)

    # Convert modules to their QAT versions; the model should be moved to the
    # target device only after this conversion.
    convert(model, QAT_QUANTIZED_MODULE_MAPPING, inplace=True)
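# --- Usage sketch (not from the original source) ---
# A minimal, hypothetical driver for prepare_for_qat above. `SmallNet` is an
# assumed toy model; the only contract the function requires is a
# `fuse_model(fuse_relu=...)` method. Note that training-mode fusion is
# version-dependent (newer PyTorch releases expose fuse_modules_qat for this).
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    def fuse_model(self, fuse_relu):
        mods = ["conv", "bn", "relu"] if fuse_relu else ["conv", "bn"]
        torch.quantization.fuse_modules(self, [mods], inplace=True)


model = SmallNet().train()
prepare_for_qat(model, quantize_weights_per_channel=True, fuse_relu=True)
model.to("cpu")  # move to the target device only after conversion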
def test_fq_module(self, device, X):
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, axis, torch_type) = X
    quant_min = torch.iinfo(torch_type).min
    quant_max = torch.iinfo(torch_type).max

    X = to_tensor(X, device)
    X.requires_grad_()
    fq_module = FakeQuantize(
        default_per_channel_weight_observer, quant_min, quant_max, ch_axis=axis
    ).to(device)
    Y_prime = fq_module(X)
    assert fq_module.scale is not None
    assert fq_module.zero_point is not None
    Y = _fake_quantize_per_channel_affine_reference(
        X, fq_module.scale, fq_module.zero_point, axis, quant_min, quant_max
    )
    np.testing.assert_allclose(
        Y.cpu().detach().numpy(),
        Y_prime.cpu().detach().numpy(),
        rtol=tolerance,
        atol=tolerance,
    )

    # Test backward.
    dout = torch.rand(X.shape, dtype=torch.float, device=device)
    Y_prime.backward(dout)
    dX = _fake_quantize_per_channel_affine_grad_reference(
        dout, X, fq_module.scale, fq_module.zero_point, axis, quant_min, quant_max
    )
    np.testing.assert_allclose(
        dX.cpu().numpy(),
        X.grad.cpu().detach().numpy(),
        rtol=tolerance,
        atol=tolerance,
    )
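# For reference, the per-tensor form of the affine fake-quantization that the
# `_fake_quantize_per_channel_affine_*` helpers above are assumed to
# implement; the per-channel variant simply applies it along `ch_axis` with
# vector-valued scale and zero_point:
def _fake_quantize_per_tensor_affine_reference(X, scale, zero_point,
                                               quant_min, quant_max):
    # Quantize (scale, round, shift, clamp to the integer range)...
    Xq = torch.clamp(torch.round(X / scale) + zero_point, quant_min, quant_max)
    # ...then dequantize back to float.
    return (Xq - zero_point) * scale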
def test_fq_serializable(self):
    observer = default_per_channel_weight_observer
    quant_min = -128
    quant_max = 127
    fq_module = FakeQuantize(observer, quant_min, quant_max)
    X = torch.tensor(
        [[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]],
        dtype=torch.float32,
    )
    y_ref = fq_module(X)
    state_dict = fq_module.state_dict()
    self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
    self.assertEqual(state_dict['zero_point'], [0, 0])
    b = io.BytesIO()
    torch.save(state_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)
    for key in state_dict:
        self.assertEqual(state_dict[key], loaded_dict[key])
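# Worked check of the expected per-channel qparams above: with a symmetric
# qint8 observer, scale = max(|min|, |max|) / ((quant_max - quant_min) / 2)
# and zero_point is fixed at 0. Both rows of X reproduce the asserted values:
for row, expected in zip([[-5, -3.5, -2, 0, 3, 5, 7],
                          [1, 3, 2, 5, 6.5, 8, 10]],
                         [0.054902, 0.078431]):
    scale = max(abs(min(row)), abs(max(row))) / ((127 - (-128)) / 2)
    assert abs(scale - expected) < 1e-6  # 7 / 127.5 and 10 / 127.5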
def test_fake_quant_scriptable(self):
    observer = default_observer
    quant_min = 0
    quant_max = 255
    fq_module = FakeQuantize(observer, quant_min, quant_max)
    scripted_module = torch.jit.script(fq_module)

    X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)

    fq_module(X)
    scripted_module(X)
    self.assertEqual(
        fq_module.calculate_qparams(), scripted_module.calculate_qparams()
    )

    buf = io.BytesIO()
    torch.jit.save(scripted_module, buf)
    buf.seek(0)
    loaded_module = torch.jit.load(buf)
    self.assertEqual(
        fq_module.calculate_qparams(), loaded_module.calculate_qparams()
    )
def test_fq_serializable(self):
    observer = default_observer
    quant_min = 0
    quant_max = 255
    fq_module = FakeQuantize(observer, quant_min, quant_max)
    X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
    y_ref = fq_module(X)
    state_dict = fq_module.state_dict()
    self.assertEqual(state_dict['scale'], 0.094488)
    self.assertEqual(state_dict['zero_point'], 53)
    b = io.BytesIO()
    torch.save(state_dict, b)
    b.seek(0)
    loaded_dict = torch.load(b)
    loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
    loaded_fq_module.load_state_dict(loaded_dict)
    for key in state_dict:
        self.assertEqual(state_dict[key], loaded_fq_module.state_dict()[key])
    self.assertEqual(
        loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams()
    )
def test_compare_fused_obs_fq_oss_module(self, device):
    mod = FusedMovingAvgObsFakeQuantize()
    torch.quantization.enable_fake_quant(mod)
    torch.quantization.enable_observer(mod)
    mod.to(device)

    mod_ref = FakeQuantize()
    torch.quantization.enable_fake_quant(mod_ref)
    torch.quantization.enable_observer(mod_ref)
    mod_ref.to(device)

    for i in range(10):
        x = torch.randn(5, 5, device=device)
        out = mod(x)
        out_ref = mod_ref(x)
        torch.testing.assert_allclose(out, out_ref)
        torch.testing.assert_allclose(
            mod_ref.activation_post_process.min_val,
            mod.activation_post_process.min_val,
        )
        torch.testing.assert_allclose(
            mod_ref.activation_post_process.max_val,
            mod.activation_post_process.max_val,
        )
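# FusedMovingAvgObsFakeQuantize fuses the observer update and the fake-quant
# step into a single kernel; the test above checks that it stays numerically
# identical to the unfused FakeQuantize. A sketch of opting into it through a
# QConfig (a plausible pattern, not taken from the original source):
fused_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255
    ),
    weight=FusedMovingAvgObsFakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
    ),
)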
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,
    ),
)

_TFLITE_QAT_QCONFIG = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        dtype=torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    ),
    weight=FakeQuantize.with_args(
        observer=MinMaxObserver,
        dtype=torch.qint8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,
    ),
)

_ONNX_QCONFIG = QConfig(
    activation=MinMaxObserver.with_args(
        quant_min=0,
        quant_max=255,
        reduce_range=True,
    ),
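# A minimal sketch of putting one of these qconfigs to work in the eager-mode
# QAT flow (hypothetical usage; assumes `model` is a QAT-ready float model):
model.train()
model.qconfig = _TFLITE_QAT_QCONFIG
torch.quantization.prepare_qat(model, inplace=True)
# ... fine-tune with fake quantization in the loop ...
model.eval()
quantized_model = torch.quantization.convert(model)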