Example #1
class BaseQuantizer(nn.Module):
    # pylint:disable=too-many-public-methods
    def __init__(self, qspec: PTQuantizerSpec):
        super().__init__()
        self._narrow_range = qspec.narrow_range
        self._signedness_to_force = qspec.signedness_to_force
        self._is_using_log_scale_storage = qspec.logarithm_scale
        self._half_range = qspec.half_range
        self._num_bits = CompressionParameter(
            torch.IntTensor([qspec.num_bits]),
            requires_grad=False,
            compression_lr_multiplier=qspec.compression_lr_multiplier)
        OPTIONAL_PARAMETERS_REGISTRY.register('_num_bits')
        self.level_high = None
        self.level_low = None

        self.levels = 0
        ENABLED_VAR_NAME = 'enabled'
        self.register_buffer(ENABLED_VAR_NAME, torch.IntTensor([1]))
        OPTIONAL_PARAMETERS_REGISTRY.register(ENABLED_VAR_NAME)
        self.initialized = False
        self.call_count = 0
        self._scale_shape = qspec.scale_shape
        self._export_mode = QuantizerExportMode.FAKE_QUANTIZE

        class LoadStateListener:
            """
               Check whether a quantization module are going to be updated by new values from state_dict or checkpoint.
            """
            def __init__(self, module):
                # pylint: disable=protected-access
                self.hook = module._register_load_state_dict_pre_hook(
                    partial(self.hook_fn, module=module))

            def hook_fn(self, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs, module):
                for module_key in module.state_dict().keys():
                    candidate = prefix + module_key
                    if candidate in state_dict:
                        module.initialized = True

            def close(self):
                self.hook.remove()

        self.load_listener = LoadStateListener(self)

    def enable_gradients(self):
        raise NotImplementedError

    def disable_gradients(self):
        raise NotImplementedError

    def is_enabled_quantization(self):
        with no_jit_trace():
            return self.enabled[0].item() == 1

    def enable_quantization(self):
        self.enabled[0] = 1
        self.enable_gradients()

    def disable_quantization(self):
        self.enabled[0] = 0
        self.disable_gradients()

    def forward(self, x):
        if is_debug():
            self.call_count += 1
        # TODO: refactor to get rid of extra if's and calls on each forward
        if not self.is_enabled_quantization():
            return x
        self.set_level_ranges()
        is_exporting = is_tracing_state()
        if is_exporting:
            with no_nncf_trace():
                x = self.run_export_quantization(x)

            # The underlying operator (registered via register_operator) must be executed,
            # otherwise the dynamic graph won't be traced as it was during regular inference.
            # While this does not impact the regular, non-RNN models, for which the graph
            # building and pre-/post-hook calling is only determined by input-agnostic,
            # graph-structure independent trace info (e.g. current op scope and call count),
            # this is important for LSTMs etc. where determining the "first nodes in iteration
            # scopes" depends on whether the input tensors to an operation were traced or not.
            return self.quantize(x, execute_traced_op_as_identity=True)

        return self.quantize(x, execute_traced_op_as_identity=False)

    def quantize(self, x, execute_traced_op_as_identity: bool = False):
        raise NotImplementedError

    def reset_call_counter(self):
        self.call_count = 0

    def get_trainable_params(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def apply_minmax_init(self,
                          min_values,
                          max_values,
                          log_module_name: str = None):
        """min_values and max_values must have the same shape as specified in self.scale_shape"""
        if self.initialized:
            nncf_logger.debug(
                "Skipped initializing {} - loaded from checkpoint".format(
                    log_module_name))
            return
        if torch.any(torch.eq(min_values, np.inf)) or torch.any(
                torch.eq(max_values, -np.inf)):
            raise AttributeError(
                'Statistics were not collected for {}'.format(log_module_name))
        own_device = next(self.parameters()).device
        min_values = min_values.to(own_device)
        max_values = max_values.to(own_device)
        self._apply_minmax_init(min_values, max_values, log_module_name)

    def _apply_minmax_init(self,
                           min_values,
                           max_values,
                           log_module_name: str = None):
        raise NotImplementedError

    def set_level_ranges(self):
        raise NotImplementedError

    @property
    def is_using_log_scale_storage(self):
        return self._is_using_log_scale_storage

    @property
    def signed(self):
        raise NotImplementedError

    @property
    def num_bits(self):
        with no_jit_trace():
            return self._num_bits.item()

    @num_bits.setter
    def num_bits(self, num_bits: int):
        self._num_bits.fill_(num_bits)

    @property
    def narrow_range(self) -> bool:
        return self._narrow_range

    @property
    def scale_shape(self) -> Tuple[int, ...]:
        # Per-tensor scale shapes are (1,)
        return self._scale_shape

    def broadcast_initialized_params(self, src: int = 0):
        distributed.broadcast(self._num_bits, src=src)

    def set_export_mode(self, mode: QuantizerExportMode):
        self._export_mode = mode

    def _get_input_low_input_high(self):
        raise NotImplementedError

    def _prepare_export_quantization(self, x: torch.Tensor):
        raise NotImplementedError

    def _prepare_fq_export_quantization(self, x: torch.Tensor):
        x, level_high, level_low, input_low, input_high = \
            self._prepare_export_quantization(x)
        with no_jit_trace():
            levels = level_high - level_low + 1
        return x, levels, input_low, input_high

    def _prepare_qdq_export_quantization(self, x: torch.Tensor):
        x, level_high, level_low, input_low, input_high = \
            self._prepare_export_quantization(x)
        with no_jit_trace():
            y_scale, y_zero_point = get_scale_zp_from_input_low_input_high(
                level_low, level_high, input_low, input_high)
        return x, y_scale, y_zero_point

    def run_export_quantization(self, x: torch.Tensor):
        if self._export_mode == QuantizerExportMode.FAKE_QUANTIZE:
            x, levels, input_low, input_high = \
                self._prepare_fq_export_quantization(x)
            return ExportQuantizeToFakeQuantize.apply(x, levels, input_low,
                                                      input_high, input_low,
                                                      input_high)
        if self._export_mode == QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS:
            x, y_scale, y_zero_point = self._prepare_qdq_export_quantization(x)
            if self.per_channel and y_zero_point.numel() > 1:
                if torch.allclose(y_scale - y_scale[0], torch.zeros_like(y_scale)) and \
                        torch.allclose(y_zero_point - y_zero_point[0], torch.zeros_like(y_zero_point)):
                    y_scale, y_zero_point = y_scale[0], y_zero_point[0]
                    return ExportQuantizeToONNXQuantDequant.apply(
                        x, y_scale, y_zero_point)
                raise RuntimeError(
                    "PyTorch export to ONNX using QuantizeLinear-DequantizeLinear "
                    "doesn't support per channel quantization")
            return ExportQuantizeToONNXQuantDequant.apply(
                x, y_scale, y_zero_point)
        raise RuntimeError('Unknown export mode')

    def extra_repr(self):
        return 'bit={}, ch={}'.format(self.num_bits, self.per_channel)

    def get_quantizer_config(self) -> QuantizerConfig:
        raise NotImplementedError

    @property
    def per_channel(self) -> bool:
        numel = 1
        for el in self.scale_shape:
            numel *= el
        is_per_tensor = ((numel == 1) and (len(self.scale_shape) == 1))
        return not is_per_tensor
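
Note: the per_channel property above treats a scale shape of (1,) as per-tensor and any other shape as per-channel. Below is a minimal standalone sketch of that check (plain Python, not part of the NNCF API; the helper name is made up for illustration):

from typing import Tuple

def is_per_channel(scale_shape: Tuple[int, ...]) -> bool:
    # A single scale stored with shape (1,) means per-tensor quantization;
    # any other shape, e.g. (64, 1, 1, 1) for a conv weight, is per-channel.
    numel = 1
    for dim in scale_shape:
        numel *= dim
    return not (numel == 1 and len(scale_shape) == 1)

assert is_per_channel((1,)) is False           # per-tensor
assert is_per_channel((64, 1, 1, 1)) is True   # per-channel over 64 output channels
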
Example #2
class SymmetricQuantizer(BaseQuantizer):
    SCALE_PARAM_NAME = 'scale'
    _SCALE_PARAM_STORAGE_ATTR = '_scale_param_storage'

    def __init__(self, qspec: PTQuantizerSpec):
        super().__init__(qspec)
        self.signed_tensor = CompressionParameter(
            torch.IntTensor([0]),
            requires_grad=False,
            compression_lr_multiplier=qspec.compression_lr_multiplier)
        self.collect_scale_statistics = False

        setattr(
            self, self._SCALE_PARAM_STORAGE_ATTR,
            CompressionParameter(
                torch.ones(self.scale_shape),
                requires_grad=True,
                compression_lr_multiplier=qspec.compression_lr_multiplier))
        if self._is_using_log_scale_storage:
            self._scale_param_storage.data.log_()
            self.eps = 0
        else:
            self.eps = 1e-16
        if qspec.signedness_to_force is not None:
            self.signed = int(qspec.signedness_to_force)
        self.set_level_ranges()

        self._register_load_state_dict_pre_hook(
            StorageRedirectingLoadStateDictHook(
                storage_attribute_in_module=self._SCALE_PARAM_STORAGE_ATTR,
                name_in_state_dict=self.SCALE_PARAM_NAME,
                use_log_storage_in_module=self._is_using_log_scale_storage))

        self._register_state_dict_hook(
            StorageRedirectingStateDictHook(
                storage_attribute_in_module=self._SCALE_PARAM_STORAGE_ATTR,
                name_in_state_dict=self.SCALE_PARAM_NAME,
                use_log_storage_in_module=self._is_using_log_scale_storage))

    @property
    def scale(self):
        if self._is_using_log_scale_storage:
            return self._scale_param_storage.exp()
        return self._scale_param_storage

    @scale.setter
    def scale(self, v):
        self._scale_param_storage = v
        if self._is_using_log_scale_storage:
            self._scale_param_storage.data.log_()

    def __setattr__(self, key, value):
        """
        Need to handle the redirect-storage attributes (which are implemented using Python properties
        here) specially - otherwise the torch.nn.Module's __setattr__ will try to set them during
        assignment.
        """
        if key == self.SCALE_PARAM_NAME:
            object.__setattr__(self, key, value)
        else:
            super().__setattr__(key, value)

    def enable_gradients(self):
        self._scale_param_storage.requires_grad = True

    def disable_gradients(self):
        self._scale_param_storage.requires_grad = False

    def set_level_ranges(self):
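        # When half_range is enabled, level ranges are computed for one bit less,
        # i.e. only half of the full num_bits range is used at training time;
        # the full range is restored later in _prepare_export_quantization.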
        scaled_num_bits = 1 if self._half_range else 0
        self.level_low, self.level_high, self.levels = self.calculate_level_ranges(
            self.num_bits - scaled_num_bits, self.signed)

    @staticmethod
    def calculate_level_ranges(num_bits, signed):
        return calculate_symmetric_level_ranges(num_bits, signed)

    @property
    def signed(self):
        with no_jit_trace():
            return self.signed_tensor.item() == 1

    @signed.setter
    def signed(self, signed: bool):
        self.signed_tensor.fill_(signed)

    def quantize(self, x, execute_traced_op_as_identity: bool = False):
        return symmetric_quantize(x,
                                  self.levels,
                                  self.level_low,
                                  self.level_high,
                                  self.scale,
                                  self.eps,
                                  skip=execute_traced_op_as_identity)

    def get_trainable_params(self) -> Dict[str, torch.Tensor]:
        return {self.SCALE_PARAM_NAME: self.scale.detach()}

    def _apply_minmax_init(self,
                           min_values,
                           max_values,
                           log_module_name: str = None):
        if torch.any(torch.eq(min_values, np.inf)) or torch.any(
                torch.eq(max_values, -np.inf)):
            raise AttributeError(
                'Statistics were not collected for {}'.format(log_module_name))
        sign = torch.any(torch.lt(min_values, 0))
        if self._signedness_to_force is not None and sign != self._signedness_to_force:
            nncf_logger.warning("Forcing signed to {} for module {}".format(
                self._signedness_to_force, log_module_name))
            sign = self._signedness_to_force
        self.signed = int(sign)

        abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
        SCALE_LOWER_THRESHOLD = 0.1
        mask = torch.gt(abs_max, SCALE_LOWER_THRESHOLD)
        self._scale_param_storage.data = torch.where(
            mask, abs_max,
            SCALE_LOWER_THRESHOLD * torch.ones_like(self._scale_param_storage))
        if self._is_using_log_scale_storage:
            self._scale_param_storage.data.log_()

        nncf_logger.info("Set sign: {} and scale: {} for {}".format(
            self.signed, get_flat_tensor_contents_string(self.scale),
            log_module_name))

    def broadcast_initialized_params(self, src: int = 0):
        super().broadcast_initialized_params(src)
        distributed.broadcast(self._scale_param_storage, src=src)
        distributed.broadcast(self.signed_tensor, src=src)

    def _get_input_low_input_high(self, scale, level_low, level_high, eps):
        input_range = abs(scale) + eps
        input_low = input_range * level_low / level_high
        input_high = input_range
        return input_low, input_high

    def _prepare_export_quantization(self, x: torch.Tensor):
        with no_jit_trace():
            input_low, input_high = self._get_input_low_input_high(
                self.scale, self.level_low, self.level_high, self.eps)
            level_low = self.level_low
            level_high = self.level_high
            if self._half_range:
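                # At export time, restore the full range: clamp inputs to the
                # half-range bounds, then double the level bounds and rescale
                # input_low/input_high accordingly.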
                x = torch.min(torch.max(x, input_low), input_high)
                level_low = 2 * self.level_low
                level_high = 2 * self.level_high + 1
                input_low, input_high = self._get_input_low_input_high(
                    level_high / self.level_high * self.scale, level_low,
                    level_high, self.eps)
        return x, level_high, level_low, input_low, input_high

    def get_quantizer_config(self) -> QuantizerConfig:
        return QuantizerConfig(num_bits=self.num_bits,
                               mode=QuantizationMode.SYMMETRIC,
                               signedness_to_force=self.signed,
                               per_channel=self.per_channel)
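
Note: _get_input_low_input_high above maps the learned scale to the floating-point range covered by the quantizer: input_high is |scale| + eps, and input_low is that range rescaled by level_low / level_high. Below is a standalone restatement for intuition (the 8-bit signed bounds -128/127 are an illustrative assumption, not values taken from this code):

import torch

def symmetric_input_range(scale: torch.Tensor, level_low: int, level_high: int,
                          eps: float = 1e-16):
    # Mirrors SymmetricQuantizer._get_input_low_input_high: the quantizer
    # covers [input_low, input_high] in floating point.
    input_range = torch.abs(scale) + eps
    input_low = input_range * level_low / level_high
    input_high = input_range
    return input_low, input_high

# Example with assumed 8-bit signed level bounds:
low, high = symmetric_input_range(torch.tensor([0.5]), level_low=-128, level_high=127)
print(low, high)  # ~tensor([-0.5039]), tensor([0.5000])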