Example 1
    def test_qconfig_dict(self):
        data = [(torch.randn(10, 5, dtype=torch.float) * 20, 1)]

        # Eager mode
        qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = AnnotatedNestedModel()
        eager_module.fc3.qconfig = qconfig
        eager_module.sub2.fc1.qconfig = qconfig
        # Assign weights
        eager_module.sub1.fc.weight.data.fill_(1.0)
        eager_module.sub2.fc1.module.weight.data.fill_(1.0)
        eager_module.sub2.fc2.weight.data.fill_(1.0)
        eager_module.fc3.module.weight.data.fill_(1.0)

        script_module = torch.jit.script(NestedModel())
        # Copy weights from eager_module
        script_module.sub1.fc.weight = eager_module.sub1.fc.weight
        script_module.sub2.fc1.weight = eager_module.sub2.fc1.module.weight
        script_module.sub2.fc2.weight = eager_module.sub2.fc2.weight
        script_module.fc3.weight = eager_module.fc3.module.weight

        # Quantize eager module
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)

        def get_forward(m):
            return m._c._get_method('forward')

        # Quantize script_module
        torch._C._jit_pass_constant_propagation(
            get_forward(script_module).graph)

        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
        scripted_qconfig = QConfig(activation=ScriptedObserver._c,
                                   weight=ScriptedWeightObserver._c)
        qconfig_dict = {'sub2.fc1': scripted_qconfig, 'fc3': scripted_qconfig}
        torch._C._jit_pass_insert_observers(script_module._c, "forward",
                                            qconfig_dict)

        # Run script_module and collect statistics
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls
        script_module._c = torch._C._jit_pass_insert_quant_dequant(
            script_module._c, "forward")
        # Note that observer modules are not removed right now
        torch._C._jit_pass_quant_fusion(
            script_module._c._get_method('forward').graph)
        get_forward(script_module)(data[0][0])
        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)
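
A quick aside on the API these examples revolve around (a minimal sketch, not taken from the test above): QConfig is just a named tuple of two observer factories, and the quantization passes create fresh observer instances by calling its fields.

import torch
from torch.quantization import QConfig, MinMaxObserver, default_weight_observer

my_qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),  # factory for activation observers
    weight=default_weight_observer,                            # factory for weight observers
)
act_obs = my_qconfig.activation()  # instantiate an activation observer
wt_obs = my_qconfig.weight()       # instantiate a weight observer
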
Example 2
    def test_single_layer(self):
        r"""Compare the result of quantizing single linear layer in
        eager mode and graph mode
        """
        # eager mode
        annotated_linear_model = AnnotatedSingleLayerLinearModel()
        linear_model = SingleLayerLinearModel()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach())
        linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach())
        model_eager = quantize(annotated_linear_model, test_only_eval_fn,
                               self.calib_data)

        qconfig_dict = {
            '': QConfig(
                activation=default_observer,
                weight=default_weight_observer)
        }
        model_script = quantize_script(
            torch.jit.script(linear_model),
            qconfig_dict,
            test_only_eval_fn,
            [self.calib_data],
            inplace=False)
        result_eager = model_eager(self.calib_data[0][0])
        torch._C._jit_pass_quant_fusion(model_script._c._get_module('fc1')._get_method('forward').graph)
        result_script = model_script._c._get_method('forward')(self.calib_data[0][0])
        self.assertEqual(result_eager, result_script)
Example 3
    def test_save_load_state_dict_script(self):
        """
        Tests that we can save and load state_dict for observers that are scripted
        in a quantized model.
        """
        obs_list = [
            MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver
        ]

        for obs in obs_list:
            model = SingleLayerLinearModel().eval()
            qconfig = QConfig(activation=default_observer, weight=obs)
            qconfig_dict = {'': qconfig}
            scripted = torch.jit.script(model)
            scripted = torch.quantization.prepare_jit(scripted, qconfig_dict)
            x = torch.rand(5, 5)
            scripted(x)
            obs_dict = torch.quantization.get_observer_state_dict(scripted)

            # Load stats
            scripted_2 = torch.jit.script(model)
            scripted_2 = torch.quantization.prepare_jit(
                scripted_2, qconfig_dict)
            torch.quantization.load_observer_state_dict(scripted_2, obs_dict)
            # Verify that state_dict matches exactly with original one.
            self.assertEqual(scripted.state_dict(), scripted_2.state_dict())
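
As a follow-on sketch (assuming the same torch.quantization helpers used in the test above; the file name is a placeholder): the observer state dict is an ordinary dict of tensors, so it can be written to disk and restored in another process before conversion.

torch.save(obs_dict, "observer_stats.pt")             # obs_dict from the test above
restored = torch.load("observer_stats.pt")
torch.quantization.load_observer_state_dict(scripted_2, restored)
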
Example 4
def prepare_for_qat(model, quantize_weights_per_channel, fuse_relu):
    """Prepares model for quantization aware training"""

    # fuse models
    model.fuse_model(fuse_relu=fuse_relu)

    # set qconfig
    if quantize_weights_per_channel:
        qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    else:
        print("Quantizating weights per tensor")
        qconfig = QConfig(activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            reduce_range=True),
                          weight=default_weight_fake_quant)
    model.qconfig = qconfig

    # equivalent to quantize.prepare (in place); required for a custom white list
    # propagate qconfig and add observers
    _propagate_qconfig_helper(model,
                              qconfig_dict={},
                              white_list=QAT_QCONFIG_PROPAGATE_WHITE_LIST)
    if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()):
        print("None of the submodule got qconfig applied. Make sure you "
              "passed correct configuration through `qconfig_dict` or "
              "by assigning the `.qconfig` attribute directly on submodules")
    add_observer_(model)

    # convert modules to their QAT versions; the model should be moved to the target device afterwards
    convert(model, QAT_QUANTIZED_MODULE_MAPPING, inplace=True)
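
A minimal usage sketch for the helper above (the model class and the device move are placeholders, not from the source): the model must provide a fuse_model(fuse_relu=...) method, and per the comment above it should be moved to the target device only after conversion.

model = MyQATModel()   # hypothetical model exposing fuse_model(fuse_relu=...)
prepare_for_qat(model, quantize_weights_per_channel=True, fuse_relu=True)
model.to("cuda")       # move to the training device after conversion
# ...continue with the usual QAT training loop
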
Example 5
    def test_fusion_conv_with_bias(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelForFusionWithBias().train()
                # output with no fusion.
                out_ref = model(self.img_data_2d[0][0])

                model.qconfig = QConfig(activation=torch.nn.Identity,
                                        weight=torch.nn.Identity)
                model = fuse_modules(
                    model, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]])
                prep_model = prepare_qat(model, inplace=False)
                # output with fusion but no observers.
                out_fused = prep_model(self.img_data_2d[0][0])
                self.assertEqual(out_ref, out_fused)

                model.qconfig = torch.quantization.get_default_qconfig(qengine)
                prepare_qat(model, inplace=True)

                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                    self.assertEqual(type(model.bn1), nn.Identity)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                    self.assertEqual(type(model.bn2), nn.Identity)

                checkQAT(model)
Example 6
def quantize_statically(model, inputs, data_loader, linear_only=False):
    if (hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder)
            and linear_only):
        qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False),
            weight=default_weight_observer,
        )
        qconfig_dict = {"": None}
        for layer_idx in range(len(model.encoder.encoder.transformer.layers)):
            qconfig_dict[
                "encoder.encoder.transformer.layers.{}.attention.input_projection"
                .format(layer_idx)] = qconfig
            qconfig_dict[
                "encoder.encoder.transformer.layers.{}.attention.output_projection"
                .format(layer_idx)] = qconfig
            for mlp_idx, m in enumerate(model.encoder.encoder.transformer.
                                        layers[layer_idx].residual_mlp.mlp):
                if type(m) == torch.nn.Linear:
                    qconfig_dict[
                        "encoder.encoder.transformer.layers.{}.residual_mlp.mlp.{}"
                        .format(layer_idx, mlp_idx)] = qconfig
        trace = model.graph_mode_quantize(inputs,
                                          data_loader,
                                          qconfig_dict=qconfig_dict,
                                          force_quantize=True)
    else:
        trace = model.graph_mode_quantize(inputs, data_loader)

    return trace
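
For readers unfamiliar with the qconfig_dict convention used above (a sketch with a placeholder module path): the empty-string key sets the default for the whole model, None meaning "leave in float", and fully qualified submodule paths override that default.

qconfig_dict = {
    "": None,                                # default: keep the rest of the model in float
    "encoder.some.submodule.path": qconfig,  # hypothetical path; qconfig as built in the example above
}
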
Example 7
    def test_nested3(self):
        r"""More complicated nested test case with child qconfig overrides
        parent qconfig
        """
        model = NestedModel().eval()
        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(weight=default_weight_observer(),
                                 activation=default_observer(**custom_options))
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig,
            'sub2.fc1': custom_qconfig
        }
        model = prepare_dynamic(model, qconfig_dict)

        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.sub2.fc1)
            self.checkDynamicQuantizedLinear(model.sub2.fc2)
            self.checkDynamicQuantizedLinear(model.fc3)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)
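
Besides a full qconfig_dict, the one-line dynamic API also accepts a set of module types and a target dtype; a minimal sketch using the model class from the test above:

model = torch.quantization.quantize_dynamic(
    NestedModel().eval(),
    {torch.nn.Linear},   # dynamically quantize every Linear submodule
    dtype=torch.qint8,
)
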
Example 8
def quantize_fx(model, inputs, data_loader, dynamic=True):

    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):

        static = not dynamic

        if dynamic:
            qconfig = per_channel_dynamic_qconfig
        else:
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )

        # Only linear layers
        qconfig_dict = {"": None}
        qconfig_dict["object_type"] = [(torch.nn.Linear, qconfig)]

        def calibrate(model, loader, max_samples=-1):
            model.eval()
            with torch.no_grad():
                for (idx, d) in enumerate(loader):
                    print("Running sample input #" + str(idx))
                    model(d[1]["tokens"])
                    if idx == max_samples:
                        break

        prepared_model = prepare_fx(
            model.encoder.encoder.transformer.layers.layers,
            qconfig_dict)  # fuse modules and insert observers

        model.encoder.encoder.transformer.layers.layers = prepared_model
        if static:
            calibrate(model, data_loader)  # run calibration on sample data
        model.encoder.encoder.transformer.layers.layers = convert_fx(
            prepared_model)

        # Trace the submodule in order to fix the interface
        if static:
            input1 = torch.randn([2, 1, 1024], dtype=torch.float)
            input2 = torch.randn([1, 2]).bool()
            traced = torch.jit.trace(
                model.encoder.encoder.transformer.layers.layers,
                (input1, input2))
            model.encoder.encoder.transformer.layers.layers = traced

        # Trace the overall module
        trace = model.trace(inputs)

        return trace
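
For context, a self-contained sketch of the same prepare_fx / calibrate / convert_fx flow on a plain float module (not part of the function above, and assuming the older prepare_fx(model, qconfig_dict) signature used here):

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(float_model, qconfig_dict)  # fuse modules and insert observers
prepared(torch.randn(4, 16))                      # calibration pass on sample data
quantized = convert_fx(prepared)                  # swap observed modules for quantized ones
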
Example 9
    def test_nested3(self):
        r"""More complicated nested test case with child qconfig overrides
        parent qconfig
        """
        model = NestedModel().eval()
        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(weight=default_weight_observer(),
                                 activation=default_observer(**custom_options))
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig,
            'sub2.fc1': custom_qconfig
        }
        model = prepare(model, qconfig_dict)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkQuantizedLinear(model.sub2.fc2)
            self.checkQuantizedLinear(model.fc3)
            test_only_eval_fn(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel().eval(), test_only_eval_fn,
                         self.calib_data, qconfig_dict)
        checkQuantized(model)
Example 10
    def __init__(self):
        super().__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
        self.fc3.qconfig = default_qconfig
        self.sub2.qconfig = default_qconfig

        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options),
                                 weight=default_weight_observer)
        self.sub2.fc1.qconfig = custom_qconfig

        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
        self.sub2.fc2 = QuantWrapper(self.sub2.fc2)
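
As used above, QuantWrapper simply sandwiches a float module between a QuantStub and a DeQuantStub; a minimal sketch of the resulting structure:

import torch
from torch.quantization import QuantWrapper

wrapped = QuantWrapper(torch.nn.Linear(5, 5))
# wrapped.quant   -> QuantStub   (quantizes the input after convert())
# wrapped.module  -> Linear      (the original float module)
# wrapped.dequant -> DeQuantStub (dequantizes the output after convert())
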
Example 11
def quantize_statically(model,
                        inputs,
                        data_loader,
                        linear_only=False,
                        module_swap=False):
    log_feature_usage("export.quantize.statically")
    if (hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder)
            and linear_only):
        log_accelerator_feature_usage("quantize.statically")
        qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False),
            weight=default_weight_observer,
        )
        qconfig_dict = {"": None}
        if module_swap:
            layers = model.encoder.encoder.transformer.layers.layers
            layers_str = "encoder.encoder.transformer.layers.layers"
        else:
            layers = model.encoder.encoder.transformer.layers
            layers_str = "encoder.encoder.transformer.layers"

        # skip first layer
        for layer_idx in range(1, len(layers)):
            qconfig_dict[
                layers_str +
                ".{}.attention.input_projection".format(layer_idx)] = qconfig
            qconfig_dict[
                layers_str +
                ".{}.attention.output_projection".format(layer_idx)] = qconfig
            for mlp_idx, m in enumerate(layers[layer_idx].residual_mlp.mlp):
                # Only quantize the first linear layer, otherwise there are accuracy issues
                if type(m) == torch.nn.Linear and mlp_idx < 1:
                    qconfig_dict[layers_str + ".{}.residual_mlp.mlp.{}".format(
                        layer_idx, mlp_idx)] = qconfig
        trace = model.graph_mode_quantize(inputs,
                                          data_loader,
                                          qconfig_dict=qconfig_dict,
                                          force_quantize=True)
    else:
        trace = model.graph_mode_quantize(inputs, data_loader)

    return trace
Example 12
    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
        to nnq.Linear which is the quantized version of the module
        """
        # eager mode
        model_eager = quantize(AnnotatedSingleLayerLinearModel(),
                               test_only_eval_fn, self.calib_data)

        qconfig_dict = {
            '':
            QConfig(activation=default_observer,
                    weight=default_weight_observer)
        }
        model_script = quantize_script(
            torch.jit.script(SingleLayerLinearModel()), qconfig_dict,
            test_only_eval_fn, [self.calib_data])
        result_eager = model_eager(self.calib_data[0][0])
        result_script = model_script._c._get_method('forward')(
            self.calib_data[0][0])
        self.assertEqual(result_eager, result_script)
Example 13
    def test_default(self):
        class TestM(nn.Module):
            def __init__(self, qconfig):
                super(TestM, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3).float()
                self.conv.weight.data.fill_(1.0)
                self.conv.bias.data.fill_(0.01)
                self.qconfig = qconfig
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                return self.dequant(self.conv(self.quant(x)))

        class TestScriptM(torch.jit.ScriptModule):
            def __init__(self):
                super(TestScriptM, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3).float()
                self.conv.bias.data.fill_(0.01)

            @torch.jit.script_method
            def forward(self, x):
                y = self.conv(x)
                return y

        # Test Data
        data = [(torch.randn(10, 3, 10, 10, dtype=torch.float), 1)]

        # Eager mode
        fake_qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = TestM(fake_qconfig)
        # Script mode
        script_module = TestScriptM()
        script_module.conv.weight = torch.nn.Parameter(
            eager_module.conv.weight.detach())
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)

        def get_forward(m):
            return m._c._get_method('forward')

        # TODO: test jit.script as well
        torch._C._jit_pass_constant_propagation(
            get_forward(script_module).graph)

        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
        qconfig_dict = {
            '':
            QConfig(activation=ScriptedObserver._c,
                    weight=ScriptedWeightObserver._c)
        }
        torch._C._jit_pass_insert_observers(script_module._c, "forward",
                                            qconfig_dict)
        # Run the TestScriptM model and collect statistics
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls
        script_module._c = torch._C._jit_pass_insert_quant_dequant(
            script_module._c, "forward")
        # Note that observer modules are not removed right now
        torch._C._jit_pass_quant_fusion(
            script_module._c._get_method('forward').graph)
        get_forward(script_module)(data[0][0])
        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)
Example 14
def quantize_fx(model, inputs, data_loader, dynamic=True, selective=False):

    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):

        static = not dynamic

        if dynamic:
            qconfig = per_channel_dynamic_qconfig
        else:
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )

        # Only linear layers
        qconfig_dict = {"": None}
        if static and selective:
            qconfig_dict["module_name"] = []
            layers = model.encoder.encoder.transformer.layers.layers.layers
            layers_str = "layers"
            # skip first layer
            for layer_idx in range(1, len(layers)):
                qconfig_dict["module_name"].append((
                    layers_str +
                    ".{}.attention.input_projection".format(layer_idx),
                    qconfig,
                ))
                qconfig_dict["module_name"].append((
                    layers_str +
                    ".{}.attention.output_projection".format(layer_idx),
                    qconfig,
                ))
                for mlp_idx, m in enumerate(
                        layers[layer_idx].residual_mlp.mlp):
                    # Only quantize the first linear layer, otherwise there are accuracy issues with static quantization
                    if type(m) == torch.nn.Linear and mlp_idx < 1:
                        qconfig_dict["module_name"].append((
                            layers_str + ".{}.residual_mlp.mlp.{}".format(
                                layer_idx, mlp_idx),
                            qconfig,
                        ))
        else:
            qconfig_dict["object_type"] = [(torch.nn.Linear, qconfig)]

        def calibrate(model, loader, max_samples=-1):
            model.eval()
            with torch.no_grad():
                for (idx, d) in enumerate(loader):
                    print("Running sample input #" + str(idx))
                    model(d[1]["tokens"])
                    if idx == max_samples:
                        break

        prepared_model = prepare_fx(
            model.encoder.encoder.transformer.layers.layers,
            qconfig_dict)  # fuse modules and insert observers

        model.encoder.encoder.transformer.layers.layers = prepared_model
        if static:
            calibrate(model, data_loader)  # run calibration on sample data
        model.encoder.encoder.transformer.layers.layers = convert_fx(
            prepared_model)

        # Trace the submodule in order to fix the interface
        if static:
            input1 = torch.randn([2, 1, 1024], dtype=torch.float)
            input2 = torch.randn([1, 2]).bool()
            traced = torch.jit.trace(
                model.encoder.encoder.transformer.layers.layers,
                (input1, input2))
            model.encoder.encoder.transformer.layers.layers = traced

        # Trace the overall module
        trace = model.trace(inputs)

        return trace
Example 15
    def test_compare_qparam_eager_script_default(self):
        class Observer(torch.nn.Module):
            __annotations__ = {
                'scale': Optional[torch.Tensor],
                'zero_point': Optional[torch.Tensor]
            }

            def __init__(self):
                super(Observer, self).__init__()
                self.dtype = torch.quint8
                self.qscheme = torch.per_tensor_affine
                self.scale, self.zero_point = None, None

            def forward(self, x):
                self.scale = torch.tensor([2.0])
                self.zero_point = torch.tensor([3])
                return x

            @torch.jit.export
            def calculate_qparams(self):
                return self.scale, self.zero_point

        class WeightObserver(Observer):
            def __init__(self):
                super(WeightObserver, self).__init__()
                self.dtype = torch.qint8

        class TestM(nn.Module):
            def __init__(self, qconfig):
                super(TestM, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3).float()
                self.conv.weight.data.fill_(1.0)
                self.conv.bias.data.fill_(0.01)
                self.qconfig = qconfig
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                return self.dequant(self.conv(self.quant(x)))

        class TestScriptM(torch.jit.ScriptModule):
            def __init__(self):
                super(TestScriptM, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3).float()
                self.conv.bias.data.fill_(0.01)

            @torch.jit.script_method
            def forward(self, x):
                y = self.conv(x)
                return y

        # Test Data
        data = [(torch.randn(10, 3, 10, 10, dtype=torch.float), 1)]

        # Eager mode
        fake_qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = TestM(fake_qconfig)
        # Script mode
        script_module = TestScriptM()
        script_module.conv.weight = torch.nn.Parameter(
            eager_module.conv.weight.detach())
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)

        def get_forward(m):
            return m._c._get_method('forward')

        # TODO: test jit.script as well
        torch._C._jit_pass_constant_propagation(
            get_forward(script_module).graph)

        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
        torch._C._jit_pass_prepare_quant(script_module._c, "forward",
                                         ScriptedObserver._c,
                                         ScriptedWeightObserver._c)
        # Run the TestScriptM model and collect statistics
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls
        script_module._c = torch._C._jit_pass_insert_quant_dequant(
            script_module._c, "forward")
        # Note that observer modules are not removed right now
        torch._C._jit_pass_quant_fusion(
            script_module._c._get_method('forward').graph)
        get_forward(script_module)(data[0][0])
        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)
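
The toy Observer above mirrors the real observer interface: forward() records statistics on the tensors it sees and calculate_qparams() returns (scale, zero_point). A minimal sketch with a built-in observer:

import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8)
obs(torch.randn(4))                          # forward() records the min/max range
scale, zero_point = obs.calculate_qparams()  # qparams derived from the recorded range
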
Example 16
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.nn as nn
from tests import utils
from torch.quantization import QConfig, observer

my_qconfig = QConfig(
    activation=observer.default_observer,
    weight=observer.HistogramObserver.with_args(dtype=torch.qint8,
                                                reduce_range=False),
)


class TestQuantizedBatchNorm3D(utils.TorchGlowTestCase):
    def test_batchnorm_basic(self):
        """
        Basic test of the PyTorch 3D batchnorm Node on Glow.
        """
        class SimpleQuantizedBatchNorm(nn.Module):
            def __init__(self, C, running_mean, running_var, scale,
                         zero_point):
                super(SimpleQuantizedBatchNorm, self).__init__()
                self.qconfig = my_qconfig
                self.batchnorm = nn.quantized.BatchNorm3d(C)
                self.batchnorm.scale = scale
                self.batchnorm.zero_point = zero_point
                self.batchnorm.running_mean = running_mean
                self.batchnorm.running_var = running_var
                self.relu = torch.nn.ReLU()
                self.dq = torch.nn.quantized.DeQuantize()
Example 17
import torch
from torch.quantization import (
    FakeQuantize,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    QConfig,
    float_qparams_weight_only_qconfig,
    get_default_qat_qconfig,
    get_default_qconfig,
)

# TensorFlow Lite Quantization Specs
# https://www.tensorflow.org/lite/performance/quantization_spec?hl=en
# For activations: int8 asymmetric per-tensor [-128, 127] range
# For weights: int8 symmetric per-tensor [-127, 127] range
_TFLITE_QCONFIG = QConfig(
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    ),
    weight=MinMaxObserver.with_args(dtype=torch.qint8,
                                    quant_min=-127,
                                    quant_max=127,
                                    qscheme=torch.per_tensor_symmetric),
)
_TFLITE_QAT_QCONFIG = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        dtype=torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    ),
    weight=FakeQuantize.with_args(observer=MinMaxObserver,
                                  dtype=torch.qint8,