Example #1
    def test_quantize_weight_clamping_per_channel(self):
        """ Test quant_{min, max} from per channel observer is honored by `_quantize_weight` method
        """
        fp_min, fp_max = -1000.0, 1000.0
        q8_min, q8_max = -10, 10

        float_tensor = torch.tensor([[fp_min, fp_max]])

        observer = MovingAveragePerChannelMinMaxObserver(
            averaging_constant=1.0,
            dtype=torch.qint8,
            quant_min=q8_min,
            quant_max=q8_max,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        observer(float_tensor)
        assert observer.min_val == fp_min
        assert observer.max_val == fp_max

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

        # Actual weight values can fall outside the observer's [min_val, max_val] range for the moving average observer
        float_tensor *= 1.2

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min
Example #2
 def quantize_and_pack(w, b):
     weight_observer = weight_observer_method()
     weight_observer(w)
     qweight = _quantize_weight(w.float(), weight_observer)
     packed_weight = \
         torch.ops.quantized.linear_prepack(qweight, b)
     return packed_weight
Example #3
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.
        Args:
            mod (Module): a float module, either produced by torch.quantization
              utilities or provided by the user
        """
        # derived classes override cls._FLOAT_MODULE attribute
        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
              cls._FLOAT_MODULE.__name__
        assert type(mod) == cls._FLOAT_MODULE, msg
        assert hasattr(mod, 'qconfig'), \
            'Input float module must have qconfig defined.'
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        act_scale, act_zp = mod.activation_post_process.calculate_qparams()
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
        qconv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,  # type: ignore[call-arg]
            mod.stride,
            mod.padding,
            mod.output_padding,
            mod.groups,
            mod.bias is not None,
            mod.dilation,
            mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)

        return qconv
Example #4
    def from_float(cls, mod):
        r"""Create a quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user
        """
        if hasattr(mod, 'weight_fake_quant'):
            # assert type(mod) == QATLinear, 'training mode nnq.Linear.from_float only works for nn.qat.Linear'
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
                cls._FLOAT_MODULE.__name__
            assert hasattr(
                mod, 'qconfig'), 'Input float module must have qconfig defined'
            activation_post_process = mod.activation_post_process
            if type(mod) == nni.LinearReLU:
                mod = mod[0]
            weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        qlinear.scale = float(act_scale)
        qlinear.zero_point = int(act_zp)
        return qlinear
Example #5
    def from_float(cls, mod):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user
        """
        assert type(mod) == NNLinear, \
            'nn.quantized.dynamic.Linear.from_float only works for nn.Linear'
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [
            torch.qint8, torch.float16
        ], 'The only supported dtypes for dynamic quantized linear are qint8 and float16'
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            raise RuntimeError(
                'Unsupported dtype specified for dynamic quantized Linear!')
        qlinear = Linear(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear
Example #6
    def from_float(cls, mod):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by user
        """
        assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
            nn.Embedding.__name__
        assert hasattr(
            mod, 'qconfig'
        ), 'Embedding input float module must have qconfig defined'
        from torch.quantization import float_qparams_weight_only_qconfig
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype

        assert dtype == torch.quint8, 'The only supported dtype for nnq.Embedding is torch.quint8'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized Embedding module and pass in the quantized weight
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim)
        qembedding.set_weight(qweight)
        return qembedding
Example #7
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.quantization
              utilities or provided by the user
        """
        assert type(mod) == cls._FLOAT_MODULE, \
            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), \
            'Input float module must have qconfig defined.'
        # Workaround for sequential, ConvReLU3d should probably inherit from
        # Conv3d instead
        if type(mod) == nni.ConvReLU3d:
            activation_post_process = mod[1].activation_post_process
            mod = mod[0]
        else:
            activation_post_process = mod.activation_post_process
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)

        return qconv
Example #8
 def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
     r"""Creates a qconv object and returns it.
     """
     if weight_post_process is None:
         weight_post_process = mod.qconfig.weight()
     weight_post_process(mod.weight)
     act_scale, act_zp = activation_post_process.calculate_qparams()
     assert weight_post_process.dtype == torch.qint8, \
         'Weight observer must have a dtype of qint8'
     qweight = _quantize_weight(mod.weight.float(), weight_post_process)
     # the __init__ call used is the one from derived classes and not the one from _ConvNd
     qconv = cls(
         mod.in_channels,
         mod.out_channels,
         mod.kernel_size,  # type: ignore[call-arg]
         mod.stride,
         mod.padding,
         mod.dilation,
         mod.groups,
         mod.bias is not None,
         mod.padding_mode)
     qconv.set_weight_bias(qweight, mod.bias)
     qconv.scale = float(act_scale)
     qconv.zero_point = int(act_zp)
     return qconv
Example #9
    def from_float(cls, mod):
        r"""Create a quantized embedding_bag module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user
        """
        if hasattr(mod, 'weight_fake_quant'):
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
                nn.EmbeddingBag.__name__
            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
            if mod.qconfig is not None and mod.qconfig.weight is not None:
                weight_observer = mod.qconfig.weight()
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
        assert is_float_qparams_qconfig, \
            'EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig.'

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            f'The only supported dtypes for nnq.EmbeddingBag are torch.quint8 and torch.quint4x2, got {dtype}'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized EmbeddingBag module and pass in the quantized weight
        qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding_bag.set_weight(qweight)
        return qembedding_bag
Example #10
 def _observe_and_quantize_weight(weight):
     if dtype == torch.qint8:
         weight_observer = weight_observer_method()
         weight_observer(weight)
         qweight = _quantize_weight(weight.float(), weight_observer)
         return qweight
     else:
         return weight.float()
Example #11
    def from_float(cls, mod):
        r"""Create a quantized sparse module from a float module.

        We only care about the convert at this stage, no need for observers just yet.

        TODO(zaf): Need to add the sparse params to the qconfig
        """
        assert type(mod) == cls._FLOAT_MODULE, cls._get_name() + \
            '.from_float only works for ' + cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'sparse_params'), \
            ('Expecting the Linear to have `sparse_params`. Make sure you have provided arguments '
             'in the `sparsifier.squash_mask(params_to_save=("sparse_block_shape",))` method.')
        sparse_block_shape = mod.sparse_params.get(
            'sparse_block_shape', None)  # type: ignore[operator, union-attr]
        assert isinstance(sparse_block_shape, (tuple, list))
        assert len(sparse_block_shape) == 2
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        activation_post_process = mod.activation_post_process
        weight_post_process = mod.qconfig.weight()  # type: ignore[operator, union-attr]

        # Assumption is that the weight is already sparsified by the
        # `sparsifier.convert`
        weight = mod.weight

        weight_post_process(weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[operator, union-attr]
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        w_sc, w_zp = weight_post_process.calculate_qparams()
        if isinstance(w_zp, torch.Tensor):
            assert not torch.any(
                w_zp.bool()), "All weight zero points must map to 0"
        else:
            assert w_zp == 0, 'Weight zero point must map to 0'
        qweight = _quantize_weight(weight.float(), weight_post_process)

        row_block_size = mod.sparse_params['sparse_block_shape'][0]  # type: ignore[index]
        col_block_size = mod.sparse_params['sparse_block_shape'][1]  # type: ignore[index]
        qlinear = cls(mod.in_features,
                      mod.out_features,
                      row_block_size,
                      col_block_size,
                      dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias, row_block_size,
                                col_block_size)  # type: ignore[arg-type]
        qlinear.scale = float(act_scale)
        qlinear.zero_point = int(act_zp)
        return qlinear
Example #12
    def from_float(cls, mod):
        r"""Create a quantized sparse dynamic module from a float module.

        We only care about the convert at this stage, no need for observers just yet.
        """
        assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()

        # It is important to multiply by the mask BEFORE calling the `weight_observer`
        # TODO (zaf): Mask might not be part of the qconfig (T83295194)
        weight = mod.weight
        if getattr(mod.qconfig, 'mask', False):
            weight = mod.qconfig.mask * mod.weight

        weight_observer(weight)
        dtype = weight_observer.dtype
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        w_sc, w_zp = weight_observer.calculate_qparams()
        if isinstance(w_zp, torch.Tensor):
            assert not torch.any(
                w_zp.bool()), "All weight zero points must map to 0"
        else:
            assert w_zp == 0, 'Weight zero point must map to 0'
        qweight = _quantize_weight(weight.float(), weight_observer)

        # Use these default values until we figure out how to augment
        # `mod` to contain sparse config
        row_block_size, col_block_size = QNNPACKLinearBlockSparsePattern.block_size()
        qlinear = cls(mod.in_features,
                      mod.out_features,
                      row_block_size,
                      col_block_size,
                      dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias, row_block_size,
                                col_block_size)
        return qlinear
Example #13
    def from_float(cls, mod):
        r"""Create a quantized sparse module from a float module.

        We only care about the convert at this stage, no need for observers just yet.
        """
        assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        activation_post_process = mod.activation_post_process
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        weight_post_process = mod.qconfig.weight()

        # It is important to multiply by the mask BEFORE calling the `weight_post_process`
        # TODO (zaf): Mask might not be part of the qconfig (T83295194)
        weight = mod.weight
        if getattr(mod.qconfig, 'mask', False):
            weight = mod.qconfig.mask * mod.weight

        weight_post_process(weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        w_sc, w_zp = weight_post_process.calculate_qparams()
        if isinstance(w_zp, torch.Tensor):
            assert not torch.any(
                w_zp.bool()), "All weight zero points must map to 0"
        else:
            assert w_zp == 0, 'Weight zero point must map to 0'
        qweight = _quantize_weight(weight.float(), weight_post_process)

        # Use these default values until we figure out how to augment
        # `mod` to contain sparse config
        row_block_size = 1
        col_block_size = 4
        qlinear = cls(mod.in_features,
                      mod.out_features,
                      row_block_size,
                      col_block_size,
                      dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias, row_block_size,
                                col_block_size)
        qlinear.scale = float(act_scale)
        qlinear.zero_point = int(act_zp)
        return qlinear
Example #14
    def from_float(cls, mod):
        r"""Create a quantized module from an observed float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        if hasattr(mod, 'weight_fake_quant'):
            if type_before_parametrizations(mod) == nniqat.LinearBn1d:
                mod.weight, mod.bias = fuse_linear_bn_weights(
                    mod.weight, mod.bias, mod.bn.running_mean,
                    mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias)
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            # This function does not participate in JIT, so it is OK to ignore
            # the type mismatch in assignment. Also, mypy has an issue with
            # iterables not being implemented, so we are ignoring those too.
            if not isinstance(cls._FLOAT_MODULE, Iterable):
                cls._FLOAT_MODULE = [cls._FLOAT_MODULE
                                     ]  # type: ignore[assignment]
            supported_modules = ', '.join([
                float_mod.__name__ for float_mod in cls._FLOAT_MODULE
            ])  # type: ignore[attr-defined]
            error_msg = 'nnq.{}.from_float only works for {}, but got: {}'.format(
                cls.__name__, supported_modules, type(mod))
            assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, \
                error_msg  # type: ignore[attr-defined]
            assert hasattr(
                mod, 'qconfig'), 'Input float module must have qconfig defined'
            activation_post_process = mod.activation_post_process
            if type_before_parametrizations(mod) == nni.LinearReLU:
                mod = mod[0]
            weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        qlinear.scale = float(act_scale)
        qlinear.zero_point = int(act_zp)
        return qlinear
Example #15
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user
        """
        if hasattr(mod, 'weight_fake_quant'):
            # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            #     cls.__QAT_MODULE.__name__
            if type(mod) == nniqat.ConvBn2d:
                mod.weight, mod.bias = \
                    fuse_conv_bn_weights(mod.weight, mod.bias, mod.running_mean,
                                         mod.running_var, mod.eps, mod.gamma, mod.beta)
            assert hasattr(
                mod,
                'observer'), 'Input QAT module must have observer attached'
            weight_observer = mod.weight_fake_quant
            activation_observer = mod.observer
        else:
            assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
                cls._FLOAT_MODULE.__name__
            assert hasattr(
                mod, 'qconfig'), 'Input float module must have qconfig defined'
            # workaround for sequential, ConvReLU2d should probably
            # inherit from Conv2d instead
            if type(mod) == nni.ConvReLU2d:
                activation_observer = mod[1].observer
                mod = mod[0]
            else:
                activation_observer = mod.observer
            weight_observer = mod.qconfig.weight()
            weight_observer(mod.weight)
        act_scale, act_zp = activation_observer.calculate_qparams()
        assert weight_observer.dtype == torch.qint8, 'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_observer)
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)

        return qconv
Example #16
    def from_float(cls, mod):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        float_modules = [
            torch.nn.Linear,
            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
            torch.nn.intrinsic.modules.fused.LinearReLU,
            torch.nn.qat.dynamic.Linear
        ]

        assert type(mod) in float_modules, \
            'nn.quantized.dynamic.Linear.from_float only works for one of' + \
            str([float_mod.__name__ for float_mod in float_modules])
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \
            "dynamic quantized linear are qint8 and float16 got: {}".format(dtype)
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            raise RuntimeError(
                'Unsupported dtype specified for dynamic quantized Linear!')
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear
Example #17
    def from_float(cls, mod):
        r"""Create a quantized sparse module from a float module.

        We only care about the convert at this stage, no need for observers just yet.

        TODO: Need to figure out how to store the block shapes in the mod
        """
        assert type(mod) == cls._FLOAT_MODULE, cls._get_name() + \
            '.from_float only works for ' + cls._FLOAT_MODULE.__name__
        # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        activation_post_process = mod.activation_post_process
        weight_post_process = mod.qconfig.weight()

        # Assumption is that the weight is already sparsified by the
        # `sparsifier.convert`
        weight = mod.weight

        weight_post_process(weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        w_sc, w_zp = weight_post_process.calculate_qparams()
        if isinstance(w_zp, torch.Tensor):
            assert not torch.any(
                w_zp.bool()), "All weight zero points must map to 0"
        else:
            assert w_zp == 0, 'Weight zero point must map to 0'
        qweight = _quantize_weight(weight.float(), weight_post_process)

        row_block_size, col_block_size = LinearBlockSparsePattern.block_size()
        qlinear = cls(mod.in_features,
                      mod.out_features,
                      row_block_size,
                      col_block_size,
                      dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias, row_block_size,
                                col_block_size)
        qlinear.scale = float(act_scale)
        qlinear.zero_point = int(act_zp)
        return qlinear
Example #18
        def process_weights(weight, bias, dtype):

            if dtype == torch.qint8:
                # for each layer, for each direction we need to quantize and pack
                # weights and pack parameters in this order:
                #
                #   w_ih, w_hh
                weight_observer = weight_observer_method()
                weight_observer(weight)
                qweight = _quantize_weight(weight.float(), weight_observer)
                packed_weight = \
                    torch.ops.quantized.linear_prepack(qweight, bias)

                return packed_weight
            else:
                # for each layer, for each direction we need to quantize and pack
                # weights and pack parameters in this order:
                #
                #   packed_ih, packed_hh, b_ih, b_hh
                packed_weight = torch.ops.quantized.linear_prepack_fp16(
                    weight.float(), bias)

                return packed_weight
Example #19
    def from_float(cls, mod):
        r"""Create a quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user
        """
        if hasattr(mod, 'weight_fake_quant'):
            # assert type(mod) == QATLinear, 'training mode nnq.Linear.from_float only works for nn.qat.Linear'
            weight_observer = mod.weight_fake_quant
            activation_observer = mod.observer
        else:
            assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
                cls._FLOAT_MODULE.__name__
            assert hasattr(
                mod, 'qconfig'), 'Input float module must have qconfig defined'
            assert hasattr(
                mod,
                'observer'), 'Input float module must have observer attached'
            # workaround for sequential, ConvReLU2d should probably
            # inherit from Conv2d instead
            if type(mod) == nni.LinearReLU:
                activation_observer = mod[1].observer
                mod = mod[0]
            else:
                activation_observer = mod.observer
            weight_observer = mod.qconfig.weight()
            weight_observer(mod.weight)
        act_scale, act_zp = activation_observer.calculate_qparams()
        assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_observer)
        qlinear = cls(mod.in_features, mod.out_features)
        qlinear.set_weight_bias(qweight, mod.bias)
        qlinear.scale = float(act_scale)
        qlinear.zero_point = int(act_zp)
        return qlinear
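
All of the examples above follow the same calibrate-then-quantize pattern: run a weight observer over the float weight, then pass both to `_quantize_weight`, which reads the observer's qparams and returns a quantized tensor. The minimal, self-contained sketch below illustrates that pattern for a per-tensor symmetric qint8 weight; the import paths shown are assumptions that differ across PyTorch versions (older releases expose the helper under torch.nn.quantized.modules.utils and the observer under torch.quantization).

import torch
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.nn.quantized.modules.utils import _quantize_weight  # assumed path, varies by PyTorch version

# A float weight, as it would appear on an unquantized nn.Linear.
weight = torch.randn(4, 8)

# Calibrate a symmetric qint8 observer on the weight to obtain scale/zero_point.
weight_observer = MinMaxObserver(dtype=torch.qint8,
                                 qscheme=torch.per_tensor_symmetric)
weight_observer(weight)

# Quantize the weight using the qparams computed by the observer.
qweight = _quantize_weight(weight.float(), weight_observer)
print(qweight.dtype, qweight.q_scale(), qweight.q_zero_point())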