Beispiel #1
0
    def __init__(
            self,
            num_features,
            eps=1e-5,
            momentum=0.1,
            restrict_value_type: RestrictValueType = RestrictValueType.FP,
            impl_type: ScalingImplType = ScalingImplType.STATS,
            bias_quant_type: QuantType = QuantType.FP,
            bias_narrow_range: bool = False,
            bias_bit_width: int = None):
        """Create the affine parameters, statistics storage and quantization helpers.

        Args:
            num_features: Length of the ``weight``/``bias`` parameters and of the
                running statistics buffers.
            eps: Stored unchanged on ``self.eps``.
            momentum: Stored unchanged on ``self.momentum``.
            restrict_value_type: Passed to ``RestrictValue`` (``restrict_weight``)
                and to ``RestrictValue.restrict_value_op``
                (``restrict_scaling_preprocess``).
            impl_type: ``ScalingImplType.STATS`` registers ``running_mean`` and
                ``running_var`` buffers; ``ScalingImplType.PARAMETER_FROM_STATS``
                sets both attributes to ``None``.
            bias_quant_type: Quantization type handed to ``BiasQuantProxy``.
            bias_narrow_range: Handed to ``BiasQuantProxy`` as ``narrow_range``.
            bias_bit_width: Handed to ``BiasQuantProxy`` as ``bit_width``.

        Raises:
            Exception: If ``bias_quant_type`` is not ``QuantType.FP`` while output
                scale and output bit width are not both computed, or if
                ``impl_type`` is neither of the two supported modes.
        """
        # Both bases are initialised explicitly, QuantLayer first; every
        # output-related feature is switched off for this layer.
        QuantLayer.__init__(
            self,
            compute_output_scale=False,
            compute_output_bit_width=False,
            return_quant_tensor=False)
        nn.Module.__init__(self)

        # The flags read here are presumably the ones QuantLayer.__init__ just
        # stored -- TODO confirm against QuantLayer.
        if bias_quant_type != QuantType.FP:
            if not self.compute_output_scale or not self.compute_output_bit_width:
                raise Exception("Quantizing bias requires to compute output scale and output bit width")

        # Learnable affine terms start out as the identity transform.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        if impl_type == ScalingImplType.STATS:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        elif impl_type == ScalingImplType.PARAMETER_FROM_STATS:
            # No statistics buffers are registered in this mode.
            self.running_mean = None
            self.running_var = None
        else:
            raise Exception("Scaling mode not supported")

        # Plain configuration values.
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.impl_type = impl_type
        self.restrict_value_type = restrict_value_type

        # Helper objects, assigned in the same order as before.
        self.bias_quant = BiasQuantProxy(
            quant_type=bias_quant_type,
            narrow_range=bias_narrow_range,
            bit_width=bias_bit_width)
        self.restrict_weight = RestrictValue(
            restrict_value_type=restrict_value_type,
            float_to_int_impl_type=FloatToIntImplType.ROUND,
            min_val=None)
        self.restrict_scaling_preprocess = RestrictValue.restrict_value_op(
            restrict_value_type,
            restrict_value_op_impl_type=RestrictValueOpImplType.TORCH_MODULE)
Beispiel #2
0
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int]],
            stride: Union[int, Tuple[int]] = 1,
            padding: Union[int, Tuple[int]] = 0,
            output_padding: Union[int, Tuple[int]] = 0,
            padding_type: PaddingType = PaddingType.STANDARD,
            dilation: Union[int, Tuple[int]] = 1,
            groups: int = 1,
            bias: bool = True,
            bias_quant_type: QuantType = QuantType.FP,
            bias_narrow_range: bool = False,
            bias_bit_width: int = None,
            weight_quant_override: WeightQuantProxy = None,
            weight_quant_type: QuantType = QuantType.FP,
            weight_narrow_range: bool = False,
            weight_scaling_override: Optional[Module] = None,
            weight_bit_width_impl_override: Union[BitWidthParameter, BitWidthConst] = None,
            weight_bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
            weight_restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
            weight_bit_width: int = 32,
            weight_min_overall_bit_width: Optional[int] = 2,
            weight_max_overall_bit_width: Optional[int] = None,
            weight_scaling_impl_type: ScalingImplType = ScalingImplType.STATS,
            weight_scaling_const: Optional[float] = None,
            weight_scaling_stats_op: StatsOp = StatsOp.MAX,
            weight_scaling_per_output_channel: bool = False,
            weight_ternary_threshold: float = 0.5,
            weight_restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
            weight_scaling_stats_sigma: float = 3.0,
            weight_scaling_min_val: float = SCALING_MIN_VAL,
            weight_override_pretrained_bit_width: bool = False,
            compute_output_scale: bool = False,
            compute_output_bit_width: bool = False,
            return_quant_tensor: bool = False,
            deterministic: bool = False) -> None:
        """Build a transposed 1d convolution with weight and bias quantization proxies.

        Args:
            in_channels, out_channels, kernel_size, stride, padding,
            output_padding, dilation, groups, bias: Passed through to
                ``ConvTranspose1d.__init__``.
            padding_type: Stored on ``self.padding_type``.
            bias_quant_type, bias_narrow_range, bias_bit_width: Passed to
                ``BiasQuantProxy``.
            weight_quant_override: When not ``None`` it is used as
                ``self.weight_quant`` and the layer weight is added to it as a
                tracked parameter; all other ``weight_*`` arguments are then unused.
            weight_scaling_per_output_channel: Selects the per-channel scaling
                shape, stats view and reduce dimension instead of the scalar ones.
            weight_* (remaining): Passed to ``WeightQuantProxy`` under the same
                name with the ``weight_`` prefix dropped.
            compute_output_scale, compute_output_bit_width, return_quant_tensor:
                Passed to ``QuantLayer.__init__``.
            deterministic: Written to ``torch.backends.cudnn.deterministic``, but
                only when ``torch.backends.cudnn.benchmark`` is truthy.

        Raises:
            Exception: If an output bit width is requested while weights are kept
                in floating point, or if bias is quantized without both output
                scale and output bit width being computed.
        """
        QuantLayer.__init__(
            self,
            compute_output_scale=compute_output_scale,
            compute_output_bit_width=compute_output_bit_width,
            return_quant_tensor=return_quant_tensor)
        ConvTranspose1d.__init__(
            self,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

        # Argument validation: weight check first, then the bias check.
        if compute_output_bit_width and weight_quant_type == QuantType.FP:
            raise Exception("Computing output bit width requires enabling quantization")
        if bias_quant_type != QuantType.FP:
            if not compute_output_scale or not compute_output_bit_width:
                raise Exception("Quantizing bias requires to compute output scale and output bit width")

        # Process-wide cuDNN switch; left alone unless benchmark mode is on.
        if torch.backends.cudnn.benchmark:
            torch.backends.cudnn.deterministic = deterministic

        # TODO: per_elem_ops (op count) is not implemented for this layer.
        self.padding_type = padding_type
        self.weight_reg = WeightReg()

        if weight_quant_override is None:
            # MAX_AVE forces the per-channel stats view and reduction even when
            # the scaling factor itself stays scalar.
            channelwise_stats = (weight_scaling_per_output_channel
                                 or weight_scaling_stats_op == StatsOp.MAX_AVE)
            stats_view_impl = (StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                               if channelwise_stats
                               else StatsInputViewShapeImpl.OVER_TENSOR)
            stats_reduce_dim = 1 if channelwise_stats else None
            scaling_shape = (self.per_output_channel_broadcastable_shape
                             if weight_scaling_per_output_channel
                             else SCALING_SCALAR_SHAPE)

            self.weight_quant = WeightQuantProxy(
                # quantizer basics
                quant_type=weight_quant_type,
                narrow_range=weight_narrow_range,
                ternary_threshold=weight_ternary_threshold,
                tracked_parameter_list_init=self.weight,
                # bit width
                bit_width=weight_bit_width,
                bit_width_impl_type=weight_bit_width_impl_type,
                bit_width_impl_override=weight_bit_width_impl_override,
                restrict_bit_width_type=weight_restrict_bit_width_type,
                min_overall_bit_width=weight_min_overall_bit_width,
                max_overall_bit_width=weight_max_overall_bit_width,
                override_pretrained_bit_width=weight_override_pretrained_bit_width,
                # scaling
                scaling_override=weight_scaling_override,
                scaling_impl_type=weight_scaling_impl_type,
                scaling_const=weight_scaling_const,
                scaling_shape=scaling_shape,
                scaling_min_val=weight_scaling_min_val,
                restrict_scaling_type=weight_restrict_scaling_type,
                # scaling statistics
                scaling_stats_op=weight_scaling_stats_op,
                scaling_stats_sigma=weight_scaling_stats_sigma,
                scaling_stats_reduce_dim=stats_reduce_dim,
                scaling_stats_input_view_shape_impl=stats_view_impl,
                scaling_stats_input_concat_dim=1)
        else:
            # Externally supplied proxy: just make it track this layer's weight.
            self.weight_quant = weight_quant_override
            self.weight_quant.add_tracked_parameter(self.weight)

        self.bias_quant = BiasQuantProxy(
            quant_type=bias_quant_type,
            bit_width=bias_bit_width,
            narrow_range=bias_narrow_range)
Beispiel #3
0
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            bias_quant_type: QuantType = QuantType.FP,
            bias_narrow_range: bool = False,
            bias_bit_width: int = None,
            weight_quant_override: WeightQuantProxy = None,
            weight_quant_type: QuantType = QuantType.FP,
            weight_narrow_range: bool = False,
            weight_bit_width_impl_override: Union[BitWidthParameter, BitWidthConst] = None,
            weight_bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
            weight_restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
            weight_bit_width: int = 32,
            weight_min_overall_bit_width: Optional[int] = 2,
            weight_max_overall_bit_width: Optional[int] = None,
            weight_scaling_override: Optional[Module] = None,
            weight_scaling_impl_type: ScalingImplType = ScalingImplType.STATS,
            weight_scaling_const: Optional[float] = None,
            weight_scaling_stats_op: StatsOp = StatsOp.MAX,
            weight_scaling_per_output_channel: bool = False,
            weight_scaling_min_val: float = SCALING_MIN_VAL,
            weight_ternary_threshold: float = 0.5,
            weight_restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
            weight_scaling_stats_sigma: float = 3.0,
            weight_override_pretrained_bit_width: bool = False,
            compute_output_scale: bool = False,
            compute_output_bit_width: bool = False,
            return_quant_tensor: bool = False) -> None:
        """Build a linear layer with weight and bias quantization proxies.

        Args:
            in_features, out_features, bias: Passed through to ``Linear.__init__``.
            bias_quant_type, bias_narrow_range, bias_bit_width: Passed to
                ``BiasQuantProxy``.
            weight_quant_override: When not ``None`` it is used as
                ``self.weight_quant`` and is handed the layer weight to track;
                all other ``weight_*`` arguments are then unused.
            weight_scaling_per_output_channel: Selects a ``(out_features, 1)``
                scaling shape with per-channel statistics instead of the scalar
                shape with whole-tensor statistics.
            weight_* (remaining): Passed to ``WeightQuantProxy`` under the same
                name with the ``weight_`` prefix dropped.
            compute_output_scale, compute_output_bit_width, return_quant_tensor:
                Passed to ``QuantLayer.__init__``.

        Raises:
            Exception: If an output bit width is requested while weights are kept
                in floating point, or if bias is quantized without both output
                scale and output bit width being computed.
        """
        QuantLayer.__init__(
            self,
            compute_output_scale=compute_output_scale,
            compute_output_bit_width=compute_output_bit_width,
            return_quant_tensor=return_quant_tensor)
        Linear.__init__(
            self,
            in_features=in_features,
            out_features=out_features,
            bias=bias)

        # Argument validation: weight check first, then the bias check.
        if compute_output_bit_width and weight_quant_type == QuantType.FP:
            raise Exception("Computing output bit width requires enabling quantization")
        if bias_quant_type != QuantType.FP:
            if not compute_output_scale or not compute_output_bit_width:
                raise Exception("Quantizing bias requires to compute output scale and output bit width")

        self.per_elem_ops = 2 * in_features
        self.weight_reg = WeightReg()

        if weight_quant_override is None:
            if weight_scaling_per_output_channel:
                stats_view_impl, scaling_shape, stats_reduce_dim = (
                    StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS,
                    (self.out_features, 1),
                    1)
            else:
                stats_view_impl, scaling_shape, stats_reduce_dim = (
                    StatsInputViewShapeImpl.OVER_TENSOR,
                    SCALING_SCALAR_SHAPE,
                    None)

            self.weight_quant = WeightQuantProxy(
                # quantizer basics
                quant_type=weight_quant_type,
                narrow_range=weight_narrow_range,
                ternary_threshold=weight_ternary_threshold,
                tracked_parameter_list_init=self.weight,
                # bit width
                bit_width=weight_bit_width,
                bit_width_impl_type=weight_bit_width_impl_type,
                bit_width_impl_override=weight_bit_width_impl_override,
                restrict_bit_width_type=weight_restrict_bit_width_type,
                min_overall_bit_width=weight_min_overall_bit_width,
                max_overall_bit_width=weight_max_overall_bit_width,
                override_pretrained_bit_width=weight_override_pretrained_bit_width,
                # scaling
                scaling_override=weight_scaling_override,
                scaling_impl_type=weight_scaling_impl_type,
                scaling_const=weight_scaling_const,
                scaling_shape=scaling_shape,
                scaling_min_val=weight_scaling_min_val,
                restrict_scaling_type=weight_restrict_scaling_type,
                # scaling statistics
                scaling_stats_op=weight_scaling_stats_op,
                scaling_stats_sigma=weight_scaling_stats_sigma,
                scaling_stats_reduce_dim=stats_reduce_dim,
                scaling_stats_input_view_shape_impl=stats_view_impl,
                scaling_stats_input_concat_dim=1)
        else:
            self.weight_quant = weight_quant_override
            # NOTE(review): the other layer constructors shown alongside this one
            # call `add_tracked_parameter` on the override -- confirm that
            # `add_tracked_tensor` really exists on the override's type.
            self.weight_quant.add_tracked_tensor(self.weight)

        self.bias_quant = BiasQuantProxy(
            quant_type=bias_quant_type,
            narrow_range=bias_narrow_range,
            bit_width=bias_bit_width)
    def __init__(
            self,
            num_features,
            bias_quant_type: QuantType = QuantType.FP,
            bias_narrow_range: bool = False,
            bias_bit_width: int = None,
            weight_quant_type: QuantType = QuantType.FP,
            weight_quant_override: nn.Module = None,
            weight_narrow_range: bool = False,
            weight_scaling_override: Optional[nn.Module] = None,
            weight_bit_width: int = 32,
            weight_scaling_impl_type: ScalingImplType = ScalingImplType.STATS,
            weight_scaling_const: Optional[float] = None,
            weight_scaling_stats_op: StatsOp = StatsOp.MAX,
            weight_scaling_per_output_channel: bool = False,
            weight_restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
            weight_scaling_stats_sigma: float = 3.0,
            weight_scaling_min_val: float = SCALING_MIN_VAL,
            compute_output_scale: bool = False,
            compute_output_bit_width: bool = False,
            return_quant_tensor: bool = False):
        """Build a scale-and-bias layer with weight and bias quantization proxies.

        Args:
            num_features: Passed to ``ScaleBias.__init__``; also the leading size
                of the per-channel scaling shape.
            bias_quant_type, bias_narrow_range, bias_bit_width: Passed to
                ``BiasQuantProxy``.
            weight_quant_override: When not ``None`` it is used as
                ``self.weight_quant`` and the layer weight is added to it as a
                tracked parameter; all other ``weight_*`` arguments are then unused.
            weight_scaling_per_output_channel: Selects a ``(num_features, 1)``
                scaling shape with per-channel statistics. Must be falsy when
                ``weight_scaling_stats_op`` is ``StatsOp.MAX_AVE``.
            weight_* (remaining): Passed to ``WeightQuantProxy`` under the same
                name with the ``weight_`` prefix dropped.
            compute_output_scale, compute_output_bit_width, return_quant_tensor:
                Passed to ``QuantLayer.__init__``.

        Raises:
            Exception: If bias is quantized while no output scale is computed, or
                while no bias bit width is given and none is computed.
            AssertionError: If ``StatsOp.MAX_AVE`` is combined with per-output-
                channel scaling (only while assertions are enabled).
        """
        QuantLayer.__init__(
            self,
            compute_output_scale=compute_output_scale,
            compute_output_bit_width=compute_output_bit_width,
            return_quant_tensor=return_quant_tensor)
        ScaleBias.__init__(self, num_features)

        # The flags read here are presumably the ones QuantLayer.__init__ just
        # stored -- TODO confirm against QuantLayer.
        if bias_quant_type != QuantType.FP:
            if not self.compute_output_scale:
                raise Exception("Quantizing bias requires to compute output scale")
            if bias_bit_width is None and not self.compute_output_bit_width:
                raise Exception("Quantizing bias requires a bit-width, either computed or defined")

        if weight_quant_override is None:
            if weight_scaling_stats_op == StatsOp.MAX_AVE:
                assert not weight_scaling_per_output_channel
                # Per-channel stats view, but a scalar scale and no reduce dim.
                stats_view_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                scaling_shape, stats_reduce_dim = SCALING_SCALAR_SHAPE, None
            elif weight_scaling_per_output_channel:
                stats_view_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                scaling_shape, stats_reduce_dim = (num_features, 1), 1
            else:
                stats_view_impl = StatsInputViewShapeImpl.OVER_TENSOR
                scaling_shape, stats_reduce_dim = SCALING_SCALAR_SHAPE, None

            self.weight_quant = WeightQuantProxy(
                # quantizer basics
                quant_type=weight_quant_type,
                narrow_range=weight_narrow_range,
                ternary_threshold=None,
                tracked_parameter_list_init=self.weight,
                # bit width: constant, with no overrides or overall bounds
                bit_width=weight_bit_width,
                bit_width_impl_type=BitWidthImplType.CONST,
                bit_width_impl_override=None,
                restrict_bit_width_type=RestrictValueType.INT,
                min_overall_bit_width=None,
                max_overall_bit_width=None,
                override_pretrained_bit_width=None,
                # scaling
                scaling_override=weight_scaling_override,
                scaling_impl_type=weight_scaling_impl_type,
                scaling_const=weight_scaling_const,
                scaling_shape=scaling_shape,
                scaling_min_val=weight_scaling_min_val,
                restrict_scaling_type=weight_restrict_scaling_type,
                # scaling statistics
                scaling_stats_op=weight_scaling_stats_op,
                scaling_stats_sigma=weight_scaling_stats_sigma,
                scaling_stats_reduce_dim=stats_reduce_dim,
                scaling_stats_input_view_shape_impl=stats_view_impl,
                scaling_stats_input_concat_dim=1)
        else:
            # Externally supplied proxy: just make it track this layer's weight.
            self.weight_quant = weight_quant_override
            self.weight_quant.add_tracked_parameter(self.weight)

        self.bias_quant = BiasQuantProxy(
            quant_type=bias_quant_type,
            narrow_range=bias_narrow_range,
            bit_width=bias_bit_width)