Example 1
class QuantConvTranspose1d(QuantLayer, ConvTranspose1d):
    """

        Parameters
        ----------

        %(weight_quant_proxy.parameters_with_prefix)s
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int]],
            stride: Union[int, Tuple[int]] = 1,
            padding: Union[int, Tuple[int]] = 0,
            output_padding: Union[int, Tuple[int]] = 0,
            padding_type: PaddingType = PaddingType.STANDARD,
            dilation: Union[int, Tuple[int]] = 1,
            groups: int = 1,
            bias: bool = True,
            bias_quant_type: QuantType = QuantType.FP,
            bias_narrow_range: bool = False,
            bias_bit_width: int = None,
            weight_quant_override: WeightQuantProxy = None,
            weight_quant_type: QuantType = QuantType.FP,
            weight_narrow_range: bool = False,
            weight_scaling_override: Optional[Module] = None,
            weight_bit_width_impl_override: Union[BitWidthParameter,
                                                  BitWidthConst] = None,
            weight_bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
            weight_restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
            weight_bit_width: int = 32,
            weight_min_overall_bit_width: Optional[int] = 2,
            weight_max_overall_bit_width: Optional[int] = None,
            weight_scaling_impl_type: ScalingImplType = ScalingImplType.STATS,
            weight_scaling_const: Optional[float] = None,
            weight_scaling_stats_op: StatsOp = StatsOp.MAX,
            weight_scaling_per_output_channel: bool = False,
            weight_ternary_threshold: float = 0.5,
            weight_restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
            weight_scaling_stats_sigma: float = 3.0,
            weight_scaling_min_val: float = SCALING_MIN_VAL,
            weight_override_pretrained_bit_width: bool = False,
            compute_output_scale: bool = False,
            compute_output_bit_width: bool = False,
            return_quant_tensor: bool = False,
            deterministic: bool = False) -> None:
        QuantLayer.__init__(self,
                            compute_output_scale=compute_output_scale,
                            compute_output_bit_width=compute_output_bit_width,
                            return_quant_tensor=return_quant_tensor)
        ConvTranspose1d.__init__(self,
                                 in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 output_padding=output_padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
        if weight_quant_type == QuantType.FP and compute_output_bit_width:
            raise Exception(
                "Computing output bit width requires enabling quantization")
        if bias_quant_type != QuantType.FP and not (compute_output_scale and
                                                    compute_output_bit_width):
            raise Exception(
                "Quantizing bias requires to compute output scale and output bit width"
            )

        if torch.backends.cudnn.benchmark:
            torch.backends.cudnn.deterministic = deterministic

        # self.per_elem_ops = 2 * self.kernel_size[0] * (in_channels // groups)  # TODO: implement op count
        self.padding_type = padding_type
        self.weight_reg = WeightReg()

        if weight_quant_override is not None:
            self.weight_quant = weight_quant_override
            self.weight_quant.add_tracked_parameter(self.weight)
        else:
            weight_scaling_stats_input_concat_dim = 1
            if weight_scaling_per_output_channel:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_shape = self.per_output_channel_broadcastable_shape
                weight_scaling_stats_reduce_dim = 1
            else:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
                weight_scaling_shape = SCALING_SCALAR_SHAPE
                weight_scaling_stats_reduce_dim = None

            if weight_scaling_stats_op == StatsOp.MAX_AVE:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_stats_reduce_dim = 1

            self.weight_quant = WeightQuantProxy(
                bit_width=weight_bit_width,
                quant_type=weight_quant_type,
                narrow_range=weight_narrow_range,
                scaling_override=weight_scaling_override,
                restrict_scaling_type=weight_restrict_scaling_type,
                scaling_const=weight_scaling_const,
                scaling_stats_op=weight_scaling_stats_op,
                scaling_impl_type=weight_scaling_impl_type,
                scaling_stats_reduce_dim=weight_scaling_stats_reduce_dim,
                scaling_shape=weight_scaling_shape,
                bit_width_impl_type=weight_bit_width_impl_type,
                bit_width_impl_override=weight_bit_width_impl_override,
                restrict_bit_width_type=weight_restrict_bit_width_type,
                min_overall_bit_width=weight_min_overall_bit_width,
                max_overall_bit_width=weight_max_overall_bit_width,
                tracked_parameter_list_init=self.weight,
                ternary_threshold=weight_ternary_threshold,
                scaling_stats_input_view_shape_impl=
                weight_stats_input_view_shape_impl,
                scaling_stats_input_concat_dim=
                weight_scaling_stats_input_concat_dim,
                scaling_stats_sigma=weight_scaling_stats_sigma,
                scaling_min_val=weight_scaling_min_val,
                override_pretrained_bit_width=
                weight_override_pretrained_bit_width)
        self.bias_quant = BiasQuantProxy(quant_type=bias_quant_type,
                                         bit_width=bias_bit_width,
                                         narrow_range=bias_narrow_range)

    @property
    def per_output_channel_broadcastable_shape(self):
        if self.transposed:
            raise Exception("Transposed filters are not supported.")
        else:
            output_dim = 0
        per_channel_size = [1] * len(self.weight.size())
        per_channel_size[output_dim] = self.out_channels
        per_channel_size = tuple(per_channel_size)
        return per_channel_size

    @property
    def int_weight(self):
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception(
                "Can't export int weight without quantization enabled")
        return self.weight_quant.int_weight(self.weight)

    @property
    def quant_weight_scale(self):
        """

        Returns scale factor of the quantized weights with scalar () shape or (self.out_channels, 1, 1)
        shape depending on whether scaling is per layer or per-channel.
        -------

        """
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception(
                "Can't generate scaling factor without quantization enabled")
        zero_hw_sentinel = self.weight_quant.zero_hw_sentinel
        _, scale, _ = self.weight_quant.tensor_quant(self.weight,
                                                     zero_hw_sentinel)
        return scale

    def forward(self, input, output_size=None):
        output_scale = None
        output_bit_width = None
        quant_bias_bit_width = None

        input, input_scale, input_bit_width = self.unpack_input(input)
        quant_weight, quant_weight_scale, quant_weight_bit_width = self.weight_quant(
            self.weight)
        quant_weight = self.weight_reg(quant_weight)

        if self.compute_output_bit_width:
            assert input_bit_width is not None
            output_bit_width = self.max_output_bit_width(
                input_bit_width, quant_weight_bit_width)
        if self.compute_output_scale:
            assert input_scale is not None
            output_scale = input_scale * quant_weight_scale

        output_padding = self.compute_output_padding(input, output_size)

        if self.bias is not None:
            quant_bias, _, quant_bias_bit_width = self.bias_quant(
                self.bias, output_scale, output_bit_width)
            output = self.conv_transpose1d(input, quant_weight, quant_bias,
                                           output_padding)
        else:
            output = self.conv_transpose1d(input, quant_weight, None,
                                           output_padding)

        if self.compute_output_bit_width and quant_bias_bit_width is not None:
            output_bit_width = torch.where(
                quant_bias_bit_width > output_bit_width, quant_bias_bit_width,
                output_bit_width)
            output_bit_width = output_bit_width + 1

        return self.pack_output(output, output_scale, output_bit_width)

    def compute_output_padding(self, input, output_size):
        return self._output_padding(input, output_size, self.stride,
                                    self.padding, self.kernel_size)

    def conv_transpose1d(self, x, weight, bias, output_padding):
        if self.padding_type == PaddingType.SAME:
            out = self.transposeconv1d_same_padding(x, weight, bias,
                                                    output_padding)
        else:
            out = conv_transpose1d(x, weight, bias, self.stride, self.padding,
                                   output_padding, self.groups, self.dilation)
        return out

    def transposeconv1d_same_padding(self, x, weight, bias, output_padding):
        raise Exception("SAME PADDING not supported for ConvTranspose1d")

    def merge_bn_in(self, bn, affine_only, sign_only):
        raise Exception("Merged Batch-Normalization is not yet supported")

    def max_output_bit_width(self, input_bit_width, weight_bit_width):
        max_uint_input = max_uint(bit_width=input_bit_width,
                                  narrow_range=False)
        max_kernel_val = self.weight_quant.tensor_quant.int_quant.max_uint(
            weight_bit_width)
        group_size = self.out_channels // self.groups
        overlapping_sums = max(round(self.kernel_size[0] / self.stride[0]), 1)
        max_uint_output = max_uint_input * max_kernel_val * overlapping_sums * group_size
        max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
        return max_output_bit_width
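
A minimal usage sketch for the layer above, assuming the same imports as the example (torch, QuantType, ScalingImplType) and a Brevitas version matching this constructor; the parameter values are illustrative only, not verified defaults.

# Hypothetical usage: 4-bit integer weights with per-tensor statistics-based scaling.
# Only parameters that appear in the signature above are used.
upsample = QuantConvTranspose1d(
    in_channels=64,
    out_channels=32,
    kernel_size=4,
    stride=2,
    padding=1,
    weight_quant_type=QuantType.INT,
    weight_bit_width=4,
    weight_scaling_impl_type=ScalingImplType.STATS)

x = torch.randn(8, 64, 100)  # (batch, channels, length)
y = upsample(x)              # plain tensor, since return_quant_tensor defaults to False
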
Example 2
class QuantConv2d(QuantLayer, Conv2d):
    """

        Parameters
        ----------

        %(weight_quant_proxy.parameters_with_prefix)s
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int]] = 0,
            padding_type: PaddingType = PaddingType.STANDARD,
            dilation: Union[int, Tuple[int, int]] = 1,
            groups: int = 1,
            bias: bool = True,
            bias_quant_type: QuantType = QuantType.FP,
            bias_narrow_range: bool = False,
            bias_bit_width: int = None,
            weight_quant_override: WeightQuantProxy = None,
            weight_quant_type: QuantType = QuantType.FP,
            weight_narrow_range: bool = False,
            weight_scaling_override: Optional[Module] = None,
            weight_bit_width_impl_override: Union[BitWidthParameter,
                                                  BitWidthConst] = None,
            weight_bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
            weight_restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
            weight_bit_width: int = 32,
            weight_min_overall_bit_width: Optional[int] = 2,
            weight_max_overall_bit_width: Optional[int] = None,
            weight_scaling_impl_type: ScalingImplType = ScalingImplType.STATS,
            weight_scaling_const: Optional[float] = None,
            weight_scaling_stats_op: StatsOp = StatsOp.MAX,
            weight_scaling_per_output_channel: bool = False,
            weight_ternary_threshold: float = 0.5,
            weight_restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
            weight_scaling_stats_sigma: float = 3.0,
            weight_scaling_min_val: float = SCALING_MIN_VAL,
            weight_override_pretrained_bit_width: bool = False,
            compute_output_scale: bool = False,
            compute_output_bit_width: bool = False,
            return_quant_tensor: bool = False) -> None:
        QuantLayer.__init__(self,
                            compute_output_scale=compute_output_scale,
                            compute_output_bit_width=compute_output_bit_width,
                            return_quant_tensor=return_quant_tensor)
        Conv2d.__init__(self,
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                        bias=bias)
        if weight_quant_type == QuantType.FP and compute_output_bit_width:
            raise Exception(
                "Computing output bit width requires enabling quantization")
        if bias_quant_type != QuantType.FP and not (compute_output_scale and
                                                    compute_output_bit_width):
            raise Exception(
                "Quantizing bias requires to compute output scale and output bit width"
            )

        self.per_elem_ops = 2 * self.kernel_size[0] * self.kernel_size[1] * (
            in_channels // groups)
        self.padding_type = padding_type
        self.weight_reg = WeightReg()

        if weight_quant_override is not None:
            self.weight_quant = weight_quant_override
            self.weight_quant.add_tracked_parameter(self.weight)
        else:
            weight_scaling_stats_input_concat_dim = 1
            if weight_scaling_per_output_channel:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_shape = self.per_output_channel_broadcastable_shape
                weight_scaling_stats_reduce_dim = 1
            else:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
                weight_scaling_shape = SCALING_SCALAR_SHAPE
                weight_scaling_stats_reduce_dim = None

            if weight_scaling_stats_op == StatsOp.MAX_AVE:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_stats_reduce_dim = 1

            self.weight_quant = WeightQuantProxy(
                bit_width=weight_bit_width,
                quant_type=weight_quant_type,
                narrow_range=weight_narrow_range,
                scaling_override=weight_scaling_override,
                restrict_scaling_type=weight_restrict_scaling_type,
                scaling_const=weight_scaling_const,
                scaling_stats_op=weight_scaling_stats_op,
                scaling_impl_type=weight_scaling_impl_type,
                scaling_stats_reduce_dim=weight_scaling_stats_reduce_dim,
                scaling_shape=weight_scaling_shape,
                bit_width_impl_type=weight_bit_width_impl_type,
                bit_width_impl_override=weight_bit_width_impl_override,
                restrict_bit_width_type=weight_restrict_bit_width_type,
                min_overall_bit_width=weight_min_overall_bit_width,
                max_overall_bit_width=weight_max_overall_bit_width,
                tracked_parameter_list_init=self.weight,
                ternary_threshold=weight_ternary_threshold,
                scaling_stats_input_view_shape_impl=
                weight_stats_input_view_shape_impl,
                scaling_stats_input_concat_dim=
                weight_scaling_stats_input_concat_dim,
                scaling_stats_sigma=weight_scaling_stats_sigma,
                scaling_min_val=weight_scaling_min_val,
                override_pretrained_bit_width=
                weight_override_pretrained_bit_width)
        self.bias_quant = BiasQuantProxy(quant_type=bias_quant_type,
                                         bit_width=bias_bit_width,
                                         narrow_range=bias_narrow_range)

    @property
    def per_output_channel_broadcastable_shape(self):
        if self.transposed:
            raise Exception("Transposed filters are not supported.")
        else:
            output_dim = 0
        per_channel_size = [1] * len(self.weight.size())
        per_channel_size[output_dim] = self.out_channels
        per_channel_size = tuple(per_channel_size)
        return per_channel_size

    @property
    def int_weight(self):
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception(
                "Can't export int weight without quantization enabled")
        return self.weight_quant.int_weight(self.weight)

    @property
    def quant_weight_scale(self):
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception(
                "Can't generate scaling factor without quantization enabled")
        zero_hw_sentinel = self.weight_quant.zero_hw_sentinel
        return self.weight_quant.tensor_quant.scaling_impl(zero_hw_sentinel)

    def forward(self, input):
        output_scale = None
        output_bit_width = None
        input, input_scale, input_bit_width = self.unpack_input(input)
        quant_weight, quant_weight_scale, quant_weight_bit_width = self.weight_quant(
            self.weight)
        quant_weight = self.weight_reg(quant_weight)

        if self.compute_output_bit_width:
            output_bit_width = self.max_output_bit_width(
                input_bit_width, quant_weight_bit_width)
        if self.compute_output_scale:
            output_scale = input_scale * quant_weight_scale

        if self.bias is not None:
            quant_bias = self.bias_quant(self.bias, output_scale,
                                         output_bit_width)
            output = self.conv2d(input, quant_weight, quant_bias)
        else:
            output = self.conv2d(input, quant_weight, None)
        return self.pack_output(output, output_scale, output_bit_width)

    def conv2d(self, x, weight, bias):
        if self.padding_type == PaddingType.SAME:
            out = self.conv2d_same_padding(x, weight, bias)
        else:
            out = conv2d(x, weight, bias, self.stride, self.padding,
                         self.dilation, self.groups)
        return out

    def conv2d_same_padding(self, x, weight, bias):
        ih, iw = x.size()[-2:]
        kh, kw = weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih,
            0)
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw,
            0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])
        out = F.conv2d(x, weight, bias, self.stride, 0, self.dilation,
                       self.groups)
        return out

    def merge_bn_in(self, bn, affine_only, sign_only):
        if sign_only and not isinstance(bn, QuantBatchNorm2d):
            raise Exception("Sign-only supported only with QuantBatchNorm2d")
        if affine_only and not bn.affine:
            raise Exception(
                "Affine-only merging requires BN to have affine scaling enabled."
            )
        if sign_only:
            self.weight.data *= bn.weight_sign.view(
                self.per_output_channel_broadcastable_shape)
        else:
            mul_factor, add_factor = mul_add_from_bn(
                bn_mean=bn.running_mean,
                bn_var=bn.running_var,
                bn_eps=bn.eps,
                bn_weight=bn.weight.data.clone(),
                bn_bias=bn.bias.data.clone(),
                affine_only=affine_only)
            self.weight.data *= mul_factor.view(
                self.per_output_channel_broadcastable_shape)
            if self.bias is not None:
                self.bias.data += add_factor
            else:
                self.bias = Parameter(add_factor)

    def max_output_bit_width(self, input_bit_width, weight_bit_width):
        max_uint_input = max_uint(bit_width=input_bit_width,
                                  narrow_range=False)
        max_kernel_val = self.weight_quant.tensor_quant.int_quant.max_uint(
            weight_bit_width)
        group_size = self.out_channels // self.groups
        max_uint_output = (max_uint_input * max_kernel_val *
                           self.kernel_size[0] * self.kernel_size[1] * group_size)
        max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
        return max_output_bit_width
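
A short hedged sketch of how this layer might be instantiated, assuming the example's imports (torch, QuantType, PaddingType) are in scope; the values are illustrative and not verified against a specific Brevitas release.

# Hypothetical usage: 2-bit weights, per-output-channel scaling, and SAME padding.
conv = QuantConv2d(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    stride=1,
    bias=False,
    padding_type=PaddingType.SAME,
    weight_quant_type=QuantType.INT,
    weight_bit_width=2,
    weight_scaling_per_output_channel=True)

x = torch.randn(1, 3, 32, 32)
y = conv(x)  # SAME padding at stride 1 keeps the 32x32 spatial size
# with per-channel scaling, quant_weight_scale broadcasts as (out_channels, 1, 1, 1)
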
Example 3
class QuantConv2d(QuantLayer, Conv2d):
    """

        Parameters
        ----------

        %(weight_quant_proxy.parameters_with_prefix)s
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int]] = 0,
            padding_type: PaddingType = PaddingType.STANDARD,
            dilation: Union[int, Tuple[int, int]] = 1,
            groups: int = 1,
            bias: bool = True,
            bias_quant_type: QuantType = QuantType.FP,
            bias_narrow_range: bool = False,
            bias_bit_width: int = None,
            weight_quant_override: WeightQuantProxy = None,
            weight_quant_type: QuantType = QuantType.FP,
            weight_narrow_range: bool = False,
            weight_scaling_override: Optional[Module] = None,
            weight_bit_width_impl_override: Union[BitWidthParameter,
                                                  BitWidthConst] = None,
            weight_bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
            weight_restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
            weight_bit_width: int = 32,
            weight_min_overall_bit_width: Optional[int] = 2,
            weight_max_overall_bit_width: Optional[int] = None,
            weight_scaling_impl_type: ScalingImplType = ScalingImplType.STATS,
            weight_scaling_const: Optional[float] = None,
            weight_scaling_stats_op: StatsOp = StatsOp.MAX,
            weight_scaling_per_output_channel: bool = False,
            weight_ternary_threshold: float = 0.5,
            weight_restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
            weight_scaling_stats_sigma: float = 3.0,
            weight_scaling_min_val: float = SCALING_MIN_VAL,
            weight_override_pretrained_bit_width: bool = False,
            compute_output_scale: bool = False,
            compute_output_bit_width: bool = False,
            return_quant_tensor: bool = False) -> None:
        QuantLayer.__init__(self,
                            compute_output_scale=compute_output_scale,
                            compute_output_bit_width=compute_output_bit_width,
                            return_quant_tensor=return_quant_tensor)
        Conv2d.__init__(self,
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                        bias=bias)
        # save a copy of the args passed to the constructor, used to determine
        # whether the quantization config is exportable to something FINN supports
        self.init_args = locals()
        if weight_quant_type == QuantType.FP and compute_output_bit_width:
            raise Exception(
                "Computing output bit width requires enabling quantization")
        if bias_quant_type != QuantType.FP and not (compute_output_scale and
                                                    compute_output_bit_width):
            raise Exception(
                "Quantizing bias requires to compute output scale and output bit width"
            )

        self.per_elem_ops = 2 * self.kernel_size[0] * self.kernel_size[1] * (
            in_channels // groups)
        self.padding_type = padding_type
        self.weight_reg = WeightReg()

        if weight_quant_override is not None:
            self.weight_quant = weight_quant_override
            self.weight_quant.add_tracked_parameter(self.weight)
        else:
            weight_scaling_stats_input_concat_dim = 0
            if weight_scaling_per_output_channel:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_shape = self.per_output_channel_broadcastable_shape
                weight_scaling_stats_reduce_dim = 1
            else:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
                weight_scaling_shape = SCALING_SCALAR_SHAPE
                weight_scaling_stats_reduce_dim = None

            if weight_scaling_stats_op == StatsOp.MAX_AVE:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_stats_reduce_dim = 1

            self.weight_quant = WeightQuantProxy(
                bit_width=weight_bit_width,
                quant_type=weight_quant_type,
                narrow_range=weight_narrow_range,
                scaling_override=weight_scaling_override,
                restrict_scaling_type=weight_restrict_scaling_type,
                scaling_const=weight_scaling_const,
                scaling_stats_op=weight_scaling_stats_op,
                scaling_impl_type=weight_scaling_impl_type,
                scaling_stats_reduce_dim=weight_scaling_stats_reduce_dim,
                scaling_shape=weight_scaling_shape,
                bit_width_impl_type=weight_bit_width_impl_type,
                bit_width_impl_override=weight_bit_width_impl_override,
                restrict_bit_width_type=weight_restrict_bit_width_type,
                min_overall_bit_width=weight_min_overall_bit_width,
                max_overall_bit_width=weight_max_overall_bit_width,
                tracked_parameter_list_init=self.weight,
                ternary_threshold=weight_ternary_threshold,
                scaling_stats_input_view_shape_impl=
                weight_stats_input_view_shape_impl,
                scaling_stats_input_concat_dim=
                weight_scaling_stats_input_concat_dim,
                scaling_stats_sigma=weight_scaling_stats_sigma,
                scaling_min_val=weight_scaling_min_val,
                override_pretrained_bit_width=
                weight_override_pretrained_bit_width)
        self.bias_quant = BiasQuantProxy(quant_type=bias_quant_type,
                                         bit_width=bias_bit_width,
                                         narrow_range=bias_narrow_range)

    def get_exportable_quantization_type(self):
        # Brevitas provides a wide range of possibilities for quantization,
        # but FINN only supports a subset. Here we test the quantization
        # config to see if it's something that FINN would understand.
        # TODO: the checks below are overly conservative, relax these.
        # alternatively, create specialized subclasses and only provide export
        # flows for those.
        ia = self.init_args
        if (ia["bias"] == False and ia["weight_quant_type"] == QuantType.BINARY
                and ia["weight_bit_width"] == 1
                and ia["weight_bit_width_impl_type"] == BitWidthImplType.CONST
                and ia["weight_quant_override"] == None
                and ia["weight_bit_width_impl_override"] == None
                and ia["weight_bit_width_impl_type"] == BitWidthImplType.CONST
                and ia["weight_restrict_bit_width_type"]
                == RestrictValueType.INT
                and ia["weight_min_overall_bit_width"] == 2
                and ia["weight_max_overall_bit_width"] == None
                and ia["weight_ternary_threshold"] == 0.5 and
                ia["weight_restrict_scaling_type"] == RestrictValueType.LOG_FP
                and ia["weight_override_pretrained_bit_width"] == False
                and ia["compute_output_scale"] == False
                and ia["compute_output_bit_width"] == False
                and ia["return_quant_tensor"] == False
                and ia["padding_type"] == PaddingType.STANDARD
                and ia["groups"] == 1 and ia["dilation"] == 1):
            return "BIPOLAR"
        elif (ia["bias"] == False and ia["weight_quant_type"] == QuantType.INT
              and ia["weight_bit_width_impl_type"] == BitWidthImplType.CONST
              and ia["weight_quant_override"] == None
              and ia["weight_bit_width_impl_override"] == None
              and ia["weight_bit_width_impl_type"] == BitWidthImplType.CONST
              and ia["weight_restrict_bit_width_type"] == RestrictValueType.INT
              and ia["weight_min_overall_bit_width"] == 2
              and ia["weight_max_overall_bit_width"] == None
              and ia["weight_override_pretrained_bit_width"] == False
              and ia["compute_output_scale"] == False
              and ia["compute_output_bit_width"] == False
              and ia["return_quant_tensor"] == False):
            return "INT%d" % ia["weight_bit_width"]
        else:
            raise Exception("Unsupported config combination for export")

    @property
    def per_output_channel_broadcastable_shape(self):
        if self.transposed:
            raise Exception("Transposed filters are not supported.")
        else:
            output_dim = 0
        per_channel_size = [1] * len(self.weight.size())
        per_channel_size[output_dim] = self.out_channels
        per_channel_size = tuple(per_channel_size)
        return per_channel_size

    @property
    def int_weight(self):
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception(
                "Can't export int weight without quantization enabled")
        return self.weight_quant.int_weight(self.weight)

    @property
    def quant_weight_scale(self):
        """

        Returns scale factor of the quantized weights with scalar () shape or (self.out_channels, 1, 1, 1)
        shape depending on whether scaling is per layer or per-channel.
        -------

        """
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception(
                "Can't generate scaling factor without quantization enabled")
        zero_hw_sentinel = self.weight_quant.zero_hw_sentinel
        _, scale, _ = self.weight_quant.tensor_quant(self.weight,
                                                     zero_hw_sentinel)
        return scale

    @QuantLayer.export_mode.setter
    def export_mode(self, value):
        self._export_mode = value
        # create completely detached prequantized tensors for export;
        # calling these in forward() causes the ops to be included in the graph
        # as dead-end nodes. Note: this might be fixed in PyTorch 1.2.0, in
        # which case this prepare-for-export workaround is not necessary.
        self.export_int_weight = self.int_weight.type(
            torch.FloatTensor).detach()
        self.export_quant_weight_scale = self.quant_weight_scale.type(
            torch.FloatTensor).detach()

    def forward(self, input):
        if self.export_mode:
            export_qnt_type = self.get_exportable_quantization_type()
            # provide padding and stride as lists, one elem per dimension
            # onnxruntime expects a 4D padding list, so we make one
            if isinstance(self.padding, int):
                export_pads = [
                    self.padding, self.padding, self.padding, self.padding
                ]
            else:
                # assume we have a tuple and symmetric padding
                # [x1_begin, x2_begin...x1_end, x2_end,...],
                # so just duplicate the padding tuple
                export_pads = list(self.padding) + list(self.padding)
            if isinstance(self.stride, int):
                export_strides = [self.stride, self.stride]
            else:
                export_strides = list(self.stride)
            export_scale = self.export_quant_weight_scale
            # if we have channelwise scaling factors, reshape them to match
            # the output shape while exporting
            if len(export_scale.shape) == 4:
                export_scale_shape = (1, self.out_channels, 1, 1)
                export_scale = export_scale.reshape(export_scale_shape)
            # TODO add support for biases
            export_bias = None
            return QuantizedConv2dPlaceholderFunction.apply(
                input, self.export_int_weight, export_scale, export_qnt_type,
                self.export_out_shape, export_pads, export_strides,
                export_bias, list(self.kernel_size), self.groups)
        else:
            output_scale = None
            output_bit_width = None
            quant_bias_bit_width = None

            input, input_scale, input_bit_width = self.unpack_input(input)
            quant_weight, quant_weight_scale, quant_weight_bit_width = self.weight_quant(
                self.weight)
            quant_weight = self.weight_reg(quant_weight)

            if self.compute_output_bit_width:
                assert input_bit_width is not None
                output_bit_width = self.max_output_bit_width(
                    input_bit_width, quant_weight_bit_width)
            if self.compute_output_scale:
                assert input_scale is not None
                output_scale = input_scale * quant_weight_scale

            if self.bias is not None:
                quant_bias, _, quant_bias_bit_width = self.bias_quant(
                    self.bias, output_scale, output_bit_width)
                output = self.conv2d(input, quant_weight, quant_bias)
            else:
                output = self.conv2d(input, quant_weight, None)

            if self.compute_output_bit_width and quant_bias_bit_width is not None:
                output_bit_width = torch.where(
                    quant_bias_bit_width > output_bit_width,
                    quant_bias_bit_width, output_bit_width)
                output_bit_width = output_bit_width + 1

            return self.pack_output(output, output_scale, output_bit_width)

    def conv2d(self, x, weight, bias):
        if self.padding_type == PaddingType.SAME:
            out = self.conv2d_same_padding(x, weight, bias)
        else:
            out = conv2d(x, weight, bias, self.stride, self.padding,
                         self.dilation, self.groups)
        return out

    def conv2d_same_padding(self, x, weight, bias):
        ih, iw = x.size()[-2:]
        kh, kw = weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih,
            0)
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw,
            0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])
        out = F.conv2d(x, weight, bias, self.stride, 0, self.dilation,
                       self.groups)
        return out

    def merge_bn_in(self, bn, affine_only):
        if affine_only and not bn.affine:
            raise Exception(
                "Affine-only merging requires BN to have affine scaling enabled."
            )
        else:
            mul_factor, add_factor = mul_add_from_bn(
                bn_mean=bn.running_mean,
                bn_var=bn.running_var,
                bn_eps=bn.eps,
                bn_weight=bn.weight.data.clone(),
                bn_bias=bn.bias.data.clone(),
                affine_only=affine_only)
            self.weight.data *= mul_factor.view(
                self.per_output_channel_broadcastable_shape)
            if self.bias is not None:
                self.bias.data += add_factor
            else:
                self.bias = Parameter(add_factor)

    def max_output_bit_width(self, input_bit_width, weight_bit_width):
        max_uint_input = max_uint(bit_width=input_bit_width,
                                  narrow_range=False)
        max_kernel_val = self.weight_quant.tensor_quant.int_quant.max_uint(
            weight_bit_width)
        group_size = self.out_channels // self.groups
        max_uint_output = (max_uint_input * max_kernel_val *
                           self.kernel_size[0] * self.kernel_size[1] * group_size)
        max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
        return max_output_bit_width
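
A sketch of a configuration that the get_exportable_quantization_type() check above would accept, assuming the example's imports; everything not listed is left at the defaults the check expects (standard padding, groups=1, dilation=1, no output scale or bit width).

# Hypothetical FINN-exportable setup: binary (1-bit) weights, no bias,
# remaining arguments at their defaults so the export check passes.
conv = QuantConv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    bias=False,
    weight_quant_type=QuantType.BINARY,
    weight_bit_width=1)

print(conv.get_exportable_quantization_type())  # expected: "BIPOLAR"
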
Example 4
class QuantLinear(QuantLayer, Linear):
    """

        Parameters
        ----------

        %(weight_quant_proxy.parameters_with_prefix)s
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool,
                 bias_quant_type: QuantType = QuantType.FP,
                 bias_narrow_range: bool = False,
                 bias_bit_width: int = None,
                 weight_quant_override: WeightQuantProxy = None,
                 weight_quant_type: QuantType = QuantType.FP,
                 weight_narrow_range: bool = False,
                 weight_bit_width_impl_override: Union[BitWidthParameter, BitWidthConst] = None,
                 weight_bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
                 weight_restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
                 weight_bit_width: int = 32,
                 weight_min_overall_bit_width: Optional[int] = 2,
                 weight_max_overall_bit_width: Optional[int] = None,
                 weight_scaling_override: Optional[Module] = None,
                 weight_scaling_impl_type: ScalingImplType = ScalingImplType.STATS,
                 weight_scaling_const: Optional[float] = None,
                 weight_scaling_stats_op: StatsOp = StatsOp.MAX,
                 weight_scaling_per_output_channel: bool = False,
                 weight_scaling_min_val: float = SCALING_MIN_VAL,
                 weight_ternary_threshold: float = 0.5,
                 weight_restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
                 weight_scaling_stats_sigma: float = 3.0,
                 weight_override_pretrained_bit_width: bool = False,
                 compute_output_scale: bool = False,
                 compute_output_bit_width: bool = False,
                 return_quant_tensor: bool = False) -> None:
        QuantLayer.__init__(self,
                            compute_output_scale=compute_output_scale,
                            compute_output_bit_width=compute_output_bit_width,
                            return_quant_tensor=return_quant_tensor)
        Linear.__init__(self,
                        in_features=in_features,
                        out_features=out_features,
                        bias=bias)
        if weight_quant_type == QuantType.FP and compute_output_bit_width:
            raise Exception("Computing output bit width requires enabling quantization")
        if bias_quant_type != QuantType.FP and not (compute_output_scale and compute_output_bit_width):
            raise Exception("Quantizing bias requires to compute output scale and output bit width")

        self.per_elem_ops = 2 * in_features
        self.weight_reg = WeightReg()

        if weight_quant_override is not None:
            self.weight_quant = weight_quant_override
            self.weight_quant.add_tracked_tensor(self.weight)
        else:
            weight_scaling_stats_input_concat_dim = 1
            if weight_scaling_per_output_channel:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_shape = (self.out_features, 1)
                weight_scaling_stats_reduce_dim = 1
            else:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
                weight_scaling_shape = SCALING_SCALAR_SHAPE
                weight_scaling_stats_reduce_dim = None

            self.weight_quant = WeightQuantProxy(bit_width=weight_bit_width,
                                                 quant_type=weight_quant_type,
                                                 narrow_range=weight_narrow_range,
                                                 scaling_override=weight_scaling_override,
                                                 restrict_scaling_type=weight_restrict_scaling_type,
                                                 scaling_const=weight_scaling_const,
                                                 scaling_stats_op=weight_scaling_stats_op,
                                                 scaling_impl_type=weight_scaling_impl_type,
                                                 scaling_stats_reduce_dim=weight_scaling_stats_reduce_dim,
                                                 scaling_shape=weight_scaling_shape,
                                                 bit_width_impl_type=weight_bit_width_impl_type,
                                                 bit_width_impl_override=weight_bit_width_impl_override,
                                                 restrict_bit_width_type=weight_restrict_bit_width_type,
                                                 min_overall_bit_width=weight_min_overall_bit_width,
                                                 max_overall_bit_width=weight_max_overall_bit_width,
                                                 tracked_parameter_list_init=self.weight,
                                                 ternary_threshold=weight_ternary_threshold,
                                                 scaling_stats_input_view_shape_impl=weight_stats_input_view_shape_impl,
                                                 scaling_stats_input_concat_dim=weight_scaling_stats_input_concat_dim,
                                                 scaling_stats_sigma=weight_scaling_stats_sigma,
                                                 scaling_min_val=weight_scaling_min_val,
                                                 override_pretrained_bit_width=weight_override_pretrained_bit_width)
        self.bias_quant = BiasQuantProxy(quant_type=bias_quant_type,
                                         narrow_range=bias_narrow_range,
                                         bit_width=bias_bit_width)

    @property
    def int_weight(self):
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception("Can't generate int weight without quantization enabled")
        return self.weight_quant.int_weight(self.weight)

    @property
    def quant_weight_scale(self):
        """

        Returns scale factor of the quantized weights with scalar () shape or (self.out_channels, 1)
        shape depending on whether scaling is per layer or per-channel.
        -------

        """
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception("Can't generate scaling factor without quantization enabled")
        zero_hw_sentinel = self.weight_quant.zero_hw_sentinel
        _, scale, _ = self.weight_quant.tensor_quant(self.weight, zero_hw_sentinel)
        return scale

    def forward(self, input):
        output_scale = None
        output_bit_width = None
        bias_bit_width = None

        input, input_scale, input_bit_width = self.unpack_input(input)

        quant_weight, quant_weight_scale, quant_weight_bit_width = self.weight_quant(self.weight)
        quant_weight = self.weight_reg(quant_weight)

        if self.compute_output_bit_width:
            assert input_bit_width is not None
            output_bit_width = self.max_output_bit_width(input_bit_width, quant_weight_bit_width)
        if self.compute_output_scale:
            assert input_scale is not None
            output_scale = input_scale * quant_weight_scale

        if self.bias is not None:
            quant_bias, _, bias_bit_width = self.bias_quant(self.bias, output_scale, output_bit_width)
            output = linear(input, quant_weight, quant_bias)
        else:
            output = linear(input, quant_weight, None)

        if self.compute_output_bit_width and bias_bit_width is not None:
            output_bit_width = torch.where(bias_bit_width > output_bit_width, bias_bit_width, output_bit_width)

        return self.pack_output(output, output_scale, output_bit_width)

    def max_output_bit_width(self, input_bit_width, weight_bit_width):
        max_input_val = max_uint(bit_width=input_bit_width, narrow_range=False)
        max_fc_val = self.weight_quant.tensor_quant.int_quant.max_uint(weight_bit_width)
        max_output_val = max_input_val * max_fc_val * self.in_features
        output_bit_width = ceil_ste(torch.log2(max_output_val))
        return output_bit_width
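
A hedged usage sketch for the linear layer above, again assuming the example's imports (torch, QuantType); the bias stays in floating point (QuantType.FP), so no output scale or bit width needs to be computed.

# Hypothetical usage: 8-bit integer weights with per-output-channel scaling,
# floating-point bias (the default), no quantized output metadata.
fc = QuantLinear(
    in_features=512,
    out_features=10,
    bias=True,
    weight_quant_type=QuantType.INT,
    weight_bit_width=8,
    weight_scaling_per_output_channel=True)

x = torch.randn(32, 512)
logits = fc(x)
# int_weight and quant_weight_scale are available because quantization is enabled;
# with per-channel scaling the scale has shape (out_features, 1)
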
Example 5
class QuantLinear(QuantLayer, Linear):
    """

        Parameters
        ----------

        %(weight_quant_proxy.parameters_with_prefix)s
    """
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            bias_quant_type: QuantType = QuantType.FP,
            bias_narrow_range: bool = False,
            bias_bit_width: int = None,
            weight_quant_override: WeightQuantProxy = None,
            weight_quant_type: QuantType = QuantType.FP,
            weight_narrow_range: bool = False,
            weight_bit_width_impl_override: Union[BitWidthParameter,
                                                  BitWidthConst] = None,
            weight_bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
            weight_restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
            weight_bit_width: int = 32,
            weight_min_overall_bit_width: Optional[int] = 2,
            weight_max_overall_bit_width: Optional[int] = None,
            weight_scaling_override: Optional[Module] = None,
            weight_scaling_impl_type: ScalingImplType = ScalingImplType.STATS,
            weight_scaling_const: Optional[float] = None,
            weight_scaling_stats_op: StatsOp = StatsOp.MAX,
            weight_scaling_per_output_channel: bool = False,
            weight_scaling_min_val: float = SCALING_MIN_VAL,
            weight_ternary_threshold: float = 0.5,
            weight_restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
            weight_scaling_stats_sigma: float = 3.0,
            weight_override_pretrained_bit_width: bool = False,
            compute_output_scale: bool = False,
            compute_output_bit_width: bool = False,
            return_quant_tensor: bool = False) -> None:
        QuantLayer.__init__(self,
                            compute_output_scale=compute_output_scale,
                            compute_output_bit_width=compute_output_bit_width,
                            return_quant_tensor=return_quant_tensor)
        Linear.__init__(self,
                        in_features=in_features,
                        out_features=out_features,
                        bias=bias)
        # save a copy of the args passed to the constructor, used to determine
        # whether the quantization config is exportable to something FINN supports
        self.init_args = locals()
        if weight_quant_type == QuantType.FP and compute_output_bit_width:
            raise Exception(
                "Computing output bit width requires enabling quantization")
        if bias_quant_type != QuantType.FP and not (compute_output_scale and
                                                    compute_output_bit_width):
            raise Exception(
                "Quantizing bias requires to compute output scale and output bit width"
            )

        self.per_elem_ops = 2 * in_features
        self.weight_reg = WeightReg()

        if weight_quant_override is not None:
            self.weight_quant = weight_quant_override
            self.weight_quant.add_tracked_tensor(self.weight)
        else:
            weight_scaling_stats_input_concat_dim = 0
            if weight_scaling_per_output_channel:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_shape = (self.out_features, 1)
                weight_scaling_stats_reduce_dim = 1
            else:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
                weight_scaling_shape = SCALING_SCALAR_SHAPE
                weight_scaling_stats_reduce_dim = None

            if weight_scaling_stats_op == StatsOp.MAX_AVE:
                weight_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                weight_scaling_stats_reduce_dim = 1

            self.weight_quant = WeightQuantProxy(
                bit_width=weight_bit_width,
                quant_type=weight_quant_type,
                narrow_range=weight_narrow_range,
                scaling_override=weight_scaling_override,
                restrict_scaling_type=weight_restrict_scaling_type,
                scaling_const=weight_scaling_const,
                scaling_stats_op=weight_scaling_stats_op,
                scaling_impl_type=weight_scaling_impl_type,
                scaling_stats_reduce_dim=weight_scaling_stats_reduce_dim,
                scaling_shape=weight_scaling_shape,
                bit_width_impl_type=weight_bit_width_impl_type,
                bit_width_impl_override=weight_bit_width_impl_override,
                restrict_bit_width_type=weight_restrict_bit_width_type,
                min_overall_bit_width=weight_min_overall_bit_width,
                max_overall_bit_width=weight_max_overall_bit_width,
                tracked_parameter_list_init=self.weight,
                ternary_threshold=weight_ternary_threshold,
                scaling_stats_input_view_shape_impl=
                weight_stats_input_view_shape_impl,
                scaling_stats_input_concat_dim=
                weight_scaling_stats_input_concat_dim,
                scaling_stats_sigma=weight_scaling_stats_sigma,
                scaling_min_val=weight_scaling_min_val,
                override_pretrained_bit_width=
                weight_override_pretrained_bit_width)
        self.bias_quant = BiasQuantProxy(quant_type=bias_quant_type,
                                         narrow_range=bias_narrow_range,
                                         bit_width=bias_bit_width)

    def get_exportable_quantization_type(self):
        # Brevitas provides a wide range of possibilities for quantization,
        # but FINN only supports a subset. Here we test the quantization
        # config to see if it's something that FINN would understand.
        # TODO: the checks below are overly conservative, relax these.
        # alternatively, create specialized subclasses and only provide export
        # flows for those.
        ia = self.init_args
        if (ia["weight_quant_type"] == QuantType.BINARY
                and ia["weight_bit_width"] == 1
                and ia["weight_bit_width_impl_type"] == BitWidthImplType.CONST
                and ia["weight_quant_override"] == None
                and ia["weight_bit_width_impl_override"] == None
                and ia["weight_bit_width_impl_type"] == BitWidthImplType.CONST
                and ia["weight_restrict_bit_width_type"]
                == RestrictValueType.INT
                and ia["weight_min_overall_bit_width"] == 2
                and ia["weight_max_overall_bit_width"] == None
                and ia["weight_override_pretrained_bit_width"] == False
                and ia["compute_output_scale"] == False
                and ia["compute_output_bit_width"] == False
                and ia["return_quant_tensor"] == False):
            return "BIPOLAR"
        elif (ia["weight_quant_type"] == QuantType.INT
              and ia["weight_bit_width_impl_type"] == BitWidthImplType.CONST
              and ia["weight_quant_override"] == None
              and ia["weight_bit_width_impl_override"] == None
              and ia["weight_bit_width_impl_type"] == BitWidthImplType.CONST
              and ia["weight_restrict_bit_width_type"] == RestrictValueType.INT
              and ia["weight_min_overall_bit_width"] == 2
              and ia["weight_max_overall_bit_width"] == None
              and ia["weight_override_pretrained_bit_width"] == False):
            return "INT%d" % ia["weight_bit_width"]
        else:
            raise Exception("Unsupported config combination for export")

    @QuantLayer.export_mode.setter
    def export_mode(self, value):
        self._export_mode = value
        # create completely detached prequantized tensors for export;
        # calling these in forward() causes the ops to be included in the graph
        # as dead-end nodes. Note: this might be fixed in PyTorch 1.2.0, in
        # which case this prepare-for-export workaround is not necessary.
        self.export_int_weight = torch.t(
            self.int_weight.type(torch.FloatTensor)).detach()
        self.export_quant_weight_scale = torch.t(
            self.quant_weight_scale.type(torch.FloatTensor)).detach()
        self.export_input_node_scale = None

        self.export_output_node_scale = self.export_quant_weight_scale
        if self.compute_output_scale:
            self.export_input_node_scale = self.export_in_scale.type(
                torch.FloatTensor).detach().view(-1)[0]
            self.export_output_node_scale = self.export_out_scale.type(
                torch.FloatTensor).detach().view(-1)
            if len(self.export_output_node_scale) == 1:
                self.export_output_node_scale = self.export_output_node_scale[
                    0]

        if self.bias is not None:
            if self.export_out_scale is not None:
                quant_bias, _, bias_bit_width = self.bias_quant(
                    self.bias, self.export_out_scale,
                    self.export_out_bit_width)

                self.export_bias = torch.t(quant_bias.type(
                    torch.FloatTensor)).detach()
                self.export_bias /= self.export_output_node_scale
                self.export_bias = torch.round(self.export_bias)
            else:
                self.export_bias = torch.t(self.bias.type(
                    torch.FloatTensor)).detach()
                # divide by scale as add is before mult
                self.export_bias /= self.export_output_node_scale

        else:
            self.export_bias = None

        # export input quant type if available
        ibw = self.export_in_bit_width
        if ibw is not None:
            ibw = int(ibw.item())
            if ibw in [2, 4, 8, 16, 32]:
                self.export_in_quant_type = "INT%d" % ibw
            else:
                raise Exception("Unsupported input bitwidth for export")
        else:
            self.export_in_quant_type = ibw

    @property
    def int_weight(self):
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception(
                "Can't generate int weight without quantization enabled")
        return self.weight_quant.int_weight(self.weight)

    @property
    def quant_weight_scale(self):
        """

        Returns scale factor of the quantized weights with scalar () shape or (self.out_channels, 1)
        shape depending on whether scaling is per layer or per-channel.
        -------

        """
        if isinstance(self.weight_quant.tensor_quant, IdentityQuant):
            raise Exception(
                "Can't generate scaling factor without quantization enabled")
        zero_hw_sentinel = self.weight_quant.zero_hw_sentinel
        _, scale, _ = self.weight_quant.tensor_quant(self.weight,
                                                     zero_hw_sentinel)
        return scale

    def forward(self, input):
        if self.export_mode:
            export_qnt_type = self.get_exportable_quantization_type()
            ret = QuantizedLinearPlaceholderFunction.apply(
                input, self.export_int_weight, self.export_output_node_scale,
                export_qnt_type, self.out_features, self.export_bias,
                self.export_input_node_scale, self.export_in_quant_type)
            return ret
        else:
            output_scale = None
            output_bit_width = None
            bias_bit_width = None

            input, input_scale, input_bit_width = self.unpack_input(input)

            quant_weight, quant_weight_scale, quant_weight_bit_width = self.weight_quant(
                self.weight)
            quant_weight = self.weight_reg(quant_weight)

            if self.compute_output_bit_width:
                assert input_bit_width is not None
                output_bit_width = self.max_output_bit_width(
                    input_bit_width, quant_weight_bit_width)
            if self.compute_output_scale:
                assert input_scale is not None
                output_scale = input_scale * quant_weight_scale

            if self.bias is not None:
                quant_bias, _, bias_bit_width = self.bias_quant(
                    self.bias, output_scale, output_bit_width)
                output = linear(input, quant_weight, quant_bias)
            else:
                output = linear(input, quant_weight, None)

            if self.compute_output_bit_width and bias_bit_width is not None:
                output_bit_width = torch.where(
                    bias_bit_width > output_bit_width, bias_bit_width,
                    output_bit_width)
                output_bit_width = output_bit_width + 1
            return self.pack_output(output, output_scale, output_bit_width)

    def max_output_bit_width(self, input_bit_width, weight_bit_width):
        max_input_val = max_uint(bit_width=input_bit_width, narrow_range=False)
        max_fc_val = self.weight_quant.tensor_quant.int_quant.max_uint(
            weight_bit_width)
        max_output_val = max_input_val * max_fc_val * self.in_features
        output_bit_width = ceil_ste(torch.log2(max_output_val))
        return output_bit_width
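
The bound computed by max_output_bit_width() above can be checked by hand. The sketch below mirrors that arithmetic in plain Python, under the assumption that both maxima evaluate to 2**bit_width - 1 (the exact semantics of int_quant.max_uint depend on the Brevitas version).

import math

# back-of-the-envelope version of max_output_bit_width() above; assumes both
# maxima equal 2**bit_width - 1, mirroring max_uint with narrow_range=False
def approx_max_output_bit_width(input_bit_width, weight_bit_width, in_features):
    max_input_val = 2 ** input_bit_width - 1
    max_fc_val = 2 ** weight_bit_width - 1
    max_output_val = max_input_val * max_fc_val * in_features
    return math.ceil(math.log2(max_output_val))

print(approx_max_output_bit_width(8, 8, 512))  # 255 * 255 * 512 < 2**25, so 25 bits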