def get_default_qat_qconfig(backend='fbgemm', version=1):
    """Return the default quantization-aware-training qconfig for ``backend``.

    Args:
        backend: target backend string; ``'fbgemm'`` and ``'qnnpack'`` get
            backend-specific configs, any other value falls back to the
            generic ``default_qat_qconfig``.
        version: ``None`` selects the legacy separate observer + fake-quant
            modules; ``1`` selects the fused ``FusedMovingAvgObsFakeQuantize``
            modules.

    Returns:
        QConfig

    Raises:
        AssertionError: if ``version`` is neither ``None`` nor ``1``.
    """
    # Histogram observer is too slow for quantization aware training
    if version is None:
        if backend == 'fbgemm':
            # reduce_range=True for fbgemm (7-bit activations) — TODO confirm
            # this matches the backend's instruction constraints.
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=False),
                              weight=default_weight_fake_quant)
        else:
            qconfig = default_qat_qconfig
    elif version == 1:
        # Use the fused observer + fake_quant modules for doing QAT.
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=False),
                              weight=default_fused_wt_fake_quant)
        else:
            qconfig = default_qat_qconfig_v2
    else:
        # Previously an unsupported version fell through with `qconfig`
        # unbound and raised UnboundLocalError at the return; fail loudly
        # with a clear message instead.
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qat_qconfig is not supported. Version number must be None or 1")
    return qconfig
def get_default_qat_qconfig(backend='fbgemm', version=1):
    """
    Returns the default QAT qconfig for the specified backend.

    Args:
      * `backend`: a string representing the target backend. Currently supports `fbgemm`,
        `qnnpack` and `onednn`.
      * `version`: version, for backwards compatibility. Can be `0` or `1`.

    Return:
        qconfig
    """
    # Histogram observer is too slow for quantization aware training
    if version == 0:
        if backend == 'fbgemm':
            # reduce_range=True restricts activations to 7 bits for fbgemm.
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=False),
                              weight=default_weight_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255),
                              weight=default_per_channel_weight_fake_quant)
        else:
            qconfig = default_qat_qconfig
    # Use the fused observer + fake_quant modules for doing QAT.
    elif version == 1:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=False),
                              weight=default_fused_wt_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255),
                              weight=default_fused_per_channel_wt_fake_quant)
        else:
            qconfig = default_qat_qconfig_v2
    else:
        # Original message lacked a space after the version number, yielding
        # e.g. "Version number: 2in get_default_qat_qconfig ..."; fixed.
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qat_qconfig is not supported. Version number must be 0 or 1")
    return qconfig
# NOTE(review): the original text here began with a stray
# `raise AssertionError(...)` / `return qconfig` pair — a duplicated tail of
# get_default_qat_qconfig. At module level `return` is a SyntaxError, so that
# dead fragment is removed; only the constants below are real content.

"""
Default symmetric QAT qconfig for qnnpack. And its per channel weight variant.
"""
# Symmetric int8 activations: signed range [-128, 127], no reduce_range.
# eps=2**-12 keeps the scale from collapsing to zero for tiny ranges.
default_symmetric_qnnpack_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        reduce_range=False,
        eps=2**-12),
    weight=fused_wt_fake_quant_range_neg_127_to_127)

# Same activation config; weights quantized per channel instead of per tensor.
default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
    activation=FusedMovingAvgObsFakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        reduce_range=False,
        eps=2**-12),
    weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)