Code example #1
# Imports added so the excerpt stands alone; these names live in
# torch.ao.quantization on recent PyTorch releases (torch.quantization on older ones).
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
    default_per_channel_weight_fake_quant,
    default_weight_fake_quant,
    default_fused_per_channel_wt_fake_quant,
    default_fused_wt_fake_quant,
)
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.qconfig import (
    QConfig,
    default_qat_qconfig,
    default_qat_qconfig_v2,
)

def get_default_qat_qconfig(backend='fbgemm', version=1):
    # Histogram observer is too slow for quantization aware training
    if version is None:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=False),
                              weight=default_weight_fake_quant)
        else:
            qconfig = default_qat_qconfig
    # Use the fused observer + fake_quant modules for doing QAT.
    elif version == 1:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=False),
                              weight=default_fused_wt_fake_quant)
        else:
            qconfig = default_qat_qconfig_v2
    else:
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qat_qconfig is not supported. Version number must be None or 1")
    return qconfig
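
A brief usage sketch (my illustration, not part of the source): the returned qconfig is attached to a float model, after which prepare_qat swaps in the fake-quant modules for training. TinyNet is a hypothetical stand-in model.

import torch
import torch.ao.quantization as tq

class TinyNet(torch.nn.Module):
    """Hypothetical model used only for illustration."""
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.relu = torch.nn.ReLU()
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

model = TinyNet().train()
# Attach the default QAT qconfig for the fbgemm (x86 server) backend.
model.qconfig = tq.get_default_qat_qconfig('fbgemm')
# Insert fake-quant/observer modules for quantization aware training.
qat_model = tq.prepare_qat(model)
qat_model(torch.randn(1, 3, 8, 8))   # stands in for the training loop
# After training, convert to an actual quantized model.
quantized = tq.convert(qat_model.eval())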
Code example #2
File: qconfig.py  Project: JonghyunBae/FlashNeuron
# Assumes the same torch.ao.quantization imports as code example #1.
def get_default_qat_qconfig(backend='fbgemm', version=1):
    """
    Returns the default QAT qconfig for the specified backend.

    Args:
      * `backend`: a string representing the target backend. Currently supports `fbgemm`,
        `qnnpack` and `onednn`.
      * `version`: version, for backwards compatibility. Can be `0` or `1`.

    Return:
        qconfig
    """
    # Histogram observer is too slow for quantization aware training
    if version == 0:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=True),
                              weight=default_per_channel_weight_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255,
                                                                reduce_range=False),
                              weight=default_weight_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                quant_min=0,
                                                                quant_max=255),
                              weight=default_per_channel_weight_fake_quant)
        else:
            qconfig = default_qat_qconfig
    # Use the fused observer + fake_quant modules for doing QAT.
    elif version == 1:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=True),
                              weight=default_fused_per_channel_wt_fake_quant)
        elif backend == 'qnnpack':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255,
                                                                                 reduce_range=False),
                              weight=default_fused_wt_fake_quant)
        elif backend == 'onednn':
            qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                 quant_min=0,
                                                                                 quant_max=255),
                              weight=default_fused_per_channel_wt_fake_quant)
        else:
            qconfig = default_qat_qconfig_v2
    else:
        raise AssertionError("Version number: " + str(version) +
                             " in get_default_qat_qconfig is not supported. Version number must be 0 or 1")

    return qconfig
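
A small check (my illustration, assuming a PyTorch release whose implementation matches this example, so that `version=0` is accepted): the two versions differ only in the activation fake-quant implementation. QConfig fields are `with_args` factories, so calling one instantiates the corresponding module.

import torch.ao.quantization as tq
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)

v0 = tq.get_default_qat_qconfig('fbgemm', version=0)
v1 = tq.get_default_qat_qconfig('fbgemm', version=1)

# Calling the activation factory builds the fake-quant module,
# letting us see which implementation each version selects.
assert type(v0.activation()) is FakeQuantize
assert type(v1.activation()) is FusedMovingAvgObsFakeQuantize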
Code example #3
File: qconfig.py  Project: huaxz1986/pytorch
import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
"""
Default symmetric fake_quant for activations.
"""
default_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver,
    qscheme=torch.per_tensor_symmetric,
    dtype=torch.quint8)
"""
Default symmetric fake_quant for weights.
"""
default_weight_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver,
    qscheme=torch.per_tensor_symmetric,
    dtype=torch.qint8)

# uniform activation and weight, b=8 k=2
uniform_qconfig_8bit = QConfig(
    activation=default_symmetric_fake_quant,
    weight=default_weight_symmetric_fake_quant)

# uniform activation, APoT weight, b=8 k=2
apot_weight_qconfig_8bit = QConfig(
    activation=default_symmetric_fake_quant,
    weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))

# APoT activation and uniform weight, b=8 k=2
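
The snippet is cut off after this final comment. Following the pattern of the two preceding qconfigs, a plausible completion is sketched below; the name `apot_act_qconfig_8bit` is a guess, not from the source:

apot_act_qconfig_8bit = QConfig(
    activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
    weight=default_weight_symmetric_fake_quant)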