def prepare_for_qat(model, quantize_weights_per_channel, fuse_relu):
    """Prepare ``model`` for quantization aware training (QAT), in place.

    Fuses eligible module patterns, attaches a QAT qconfig, propagates it to
    submodules on the propagation white list, inserts observers, and swaps
    modules for their QAT counterparts.

    Args:
        model: module exposing ``fuse_model(fuse_relu=...)``; mutated in place.
        quantize_weights_per_channel: if True, use the default "fbgemm" QAT
            qconfig (per-channel weight fake-quant); otherwise fall back to a
            per-tensor activation/weight configuration.
        fuse_relu: forwarded to ``model.fuse_model`` to control ReLU fusion.

    Returns:
        None. ``model`` is modified in place.
    """
    # Fuse patterns (e.g. conv/bn/relu) before observers are inserted, so the
    # fused modules are what gets observed and converted.
    model.fuse_model(fuse_relu=fuse_relu)

    # Choose the quantization configuration.
    if quantize_weights_per_channel:
        qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    else:
        # Fixed typo in user-facing message ("Quantizating" -> "Quantizing").
        print("Quantizing weights per tensor")
        qconfig = QConfig(activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            reduce_range=True),
                          weight=default_weight_fake_quant)
    model.qconfig = qconfig

    # Equivalent to torch.quantization.prepare(..., inplace=True), done by hand
    # because a custom white list is required: propagate the qconfig to
    # submodules, then attach observers.
    _propagate_qconfig_helper(model,
                              qconfig_dict={},
                              white_list=QAT_QCONFIG_PROPAGATE_WHITE_LIST)
    if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
        print("None of the submodule got qconfig applied. Make sure you "
              "passed correct configuration through `qconfig_dict` or "
              "by assigning the `.qconfig` attribute directly on submodules")
    add_observer_(model)

    # Swap modules for their QAT versions; the model should be moved to its
    # target device only after this conversion.
    convert(model, QAT_QUANTIZED_MODULE_MAPPING, inplace=True)
# NOTE(review): scraping artifact ("Esempio n. 2" / vote count) removed here;
# the original assignment header for the QConfig below was lost. Name
# reconstructed from the `_TFLITE_QAT_QCONFIG` sibling — confirm upstream.
_TFLITE_QCONFIG = QConfig(
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    ),
    weight=MinMaxObserver.with_args(dtype=torch.qint8,
                                    quant_min=-127,
                                    quant_max=127,
                                    qscheme=torch.per_tensor_symmetric),
)
_TFLITE_QAT_QCONFIG = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        dtype=torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    ),
    weight=FakeQuantize.with_args(observer=MinMaxObserver,
                                  dtype=torch.qint8,
                                  quant_min=-127,
                                  quant_max=127,
                                  qscheme=torch.per_tensor_symmetric),
)
_ONNX_QCONFIG = QConfig(
    activation=MinMaxObserver.with_args(
        quant_min=0,
        quant_max=255,
        reduce_range=True,
    ),