def get_default_qat_qconfig(backend='fbgemm', version=1):
    # Histogram observer is too slow for quantization aware training
    if version is None:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=False),
                              weight=default_weight_fake_quant)
        else:
            qconfig = default_qat_qconfig
    # Use the fused observer + fake_quant modules for doing QAT.
    elif version == 1:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=False),
                              weight=default_fused_wt_fake_quant)
        else:
            qconfig = default_qat_qconfig_v2
    return qconfig
def get_default_qat_qconfig(backend='fbgemm', version=1):
    """
    Returns the default QAT qconfig for the specified backend.

    Args:
      * `backend`: a string representing the target backend. Currently supports
        `fbgemm`, `qnnpack` and `onednn`.
      * `version`: version, for backwards compatibility. Can be `0` or `1`.

    Return:
        qconfig
    """
    # Histogram observer is too slow for quantization aware training
    if version == 0:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=False),
                              weight=default_weight_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255),
                              weight=default_per_channel_weight_fake_quant)
        else:
            qconfig = default_qat_qconfig
    # Use the fused observer + fake_quant modules for doing QAT.
    elif version == 1:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=False),
                              weight=default_fused_wt_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255),
                              weight=default_fused_per_channel_wt_fake_quant)
        else:
            qconfig = default_qat_qconfig_v2
    else:
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qat_qconfig is not supported. Version number must be 0 or 1")
    return qconfig
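
# Usage sketch (not part of the module above): a minimal eager-mode QAT flow that
# consumes the qconfig returned by get_default_qat_qconfig. The tiny `ToyModel`
# and its layer sizes are hypothetical examples; QuantStub, DeQuantStub,
# prepare_qat and convert are the standard torch.ao.quantization eager-mode
# entry points, imported here so the sketch runs standalone.
import torch
import torch.nn as nn
from torch.ao.quantization import (QuantStub, DeQuantStub,
                                   get_default_qat_qconfig, prepare_qat, convert)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # marks where activations start being quantized
        self.fc = nn.Linear(16, 4)
        self.dequant = DeQuantStub()  # marks where activations are dequantized again

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = ToyModel().train()
model.qconfig = get_default_qat_qconfig('fbgemm')  # version=1: fused observer + fake-quant
prepared = prepare_qat(model)                      # inserts fake-quant modules
prepared(torch.randn(8, 16))                       # stand-in for a real training loop
quantized = convert(prepared.eval())               # produces the quantized model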
import torch

from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize

# Default symmetric fake_quant for activations.
default_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
                                                      qscheme=torch.per_tensor_symmetric,
                                                      dtype=torch.quint8)

# Default symmetric fake_quant for weights.
default_weight_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
                                                             qscheme=torch.per_tensor_symmetric,
                                                             dtype=torch.qint8)

# uniform activation and weight, b=8 k=2
uniform_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
                               weight=default_weight_symmetric_fake_quant)

# uniform activation, APoT weight, b=8 k=2
apot_weight_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
                                   weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))

# APoT activation and uniform weight, b=8 k=2
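
# Usage sketch: each QConfig field is a factory, and calling it constructs the
# corresponding fake-quant module. The input tensor below is illustrative only,
# and only the uniform FakeQuantize path is exercised here since APoTFakeQuantize
# is experimental.
act_fake_quant = uniform_qconfig_8bit.activation()  # builds a FakeQuantize module
x = torch.randn(4, 8)
y = act_fake_quant(x)  # observes x, then returns a fake-quantized copy of it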