Example #1
 def __init__(
         self,
         bit_width: int,
         min_val: float = -1.0,
         max_val: float = 1.0,
         narrow_range: bool = False,
         quant_type: QuantType = QuantType.FP,
         float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
         scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
         scaling_override: Optional[Module] = None,
         scaling_per_channel: bool = False,
         scaling_stats_sigma: float = 3.0,
         scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
         scaling_stats_buffer_momentum: float = 0.1,
         scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
         per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
         min_overall_bit_width: Optional[int] = 2,
         max_overall_bit_width: Optional[int] = None,
         bit_width_impl_override: Optional[BitWidthParameter] = None,
         bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
         restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
         restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
         scaling_min_val: Optional[float] = SCALING_MIN_VAL,
         override_pretrained_bit_width: bool = False,
         return_quant_tensor: bool = False):
     super(QuantHardTanh,
           self).__init__(return_quant_tensor=return_quant_tensor)
     if quant_type == QuantType.FP:
         activation_impl = ConstScalarClamp(min_val=min_val,
                                            max_val=max_val)
     else:
         activation_impl = Identity()
     self.act_quant_proxy = ActivationQuantProxy(
         activation_impl=activation_impl,
         bit_width=bit_width,
         signed=True,
         narrow_range=narrow_range,
         scaling_override=scaling_override,
         min_val=min_val,
         max_val=max_val,
         quant_type=quant_type,
         float_to_int_impl_type=float_to_int_impl_type,
         scaling_impl_type=scaling_impl_type,
         scaling_per_channel=scaling_per_channel,
         scaling_min_val=scaling_min_val,
         per_channel_broadcastable_shape=per_channel_broadcastable_shape,
         min_overall_bit_width=min_overall_bit_width,
         max_overall_bit_width=max_overall_bit_width,
         bit_width_impl_override=bit_width_impl_override,
         bit_width_impl_type=bit_width_impl_type,
         restrict_bit_width_type=restrict_bit_width_type,
         restrict_scaling_type=restrict_scaling_type,
         override_pretrained_bit_width=override_pretrained_bit_width,
         scaling_stats_sigma=scaling_stats_sigma,
         scaling_stats_op=scaling_stats_op,
         scaling_stats_buffer_momentum=scaling_stats_buffer_momentum,
         scaling_stats_permute_dims=scaling_stats_permute_dims)
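A minimal usage sketch for the QuantHardTanh constructor above. The brevitas.nn and brevitas.core.quant import paths and the QuantType.INT enum member are assumptions, not shown in the snippet itself:

import torch
from brevitas.nn import QuantHardTanh      # import path is an assumption
from brevitas.core.quant import QuantType  # import path is an assumption

# 8-bit signed quantized hard-tanh clamped to [-1.0, 1.0]; QuantType.INT is
# assumed to select integer quantization instead of the QuantType.FP default
act = QuantHardTanh(bit_width=8, min_val=-1.0, max_val=1.0,
                    quant_type=QuantType.INT)
y = act(torch.randn(4, 3, 8, 8))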
Example #2
 def __init__(
         self,
         bit_width: int,
         max_val: float,
         quant_type: QuantType = QuantType.FP,
         float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
         scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
         scaling_override: Optional[Module] = None,
         scaling_per_channel: bool = False,
         scaling_min_val: Optional[float] = SCALING_MIN_VAL,
         scaling_stats_sigma: float = 2.0,
         scaling_stats_input_view_shape_impl: Optional[StatsInputViewShapeImpl] =
             StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS,
         scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
         scaling_stats_buffer_momentum: float = 0.1,
         scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
         per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
         min_overall_bit_width: Optional[int] = 2,
         max_overall_bit_width: Optional[int] = None,
         bit_width_impl_override: Optional[BitWidthParameter] = None,
         bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
         restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
         restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
         override_pretrained_bit_width: bool = False,
         return_quant_tensor: bool = False):
     super(QuantReLU,
           self).__init__(return_quant_tensor=return_quant_tensor)
     activation_impl = nn.ReLU()
     self.act_quant_proxy = ActivationQuantProxy(
         activation_impl=activation_impl,
         bit_width=bit_width,
         signed=False,
         narrow_range=False,
         scaling_override=scaling_override,
         min_val=0.0,
         max_val=max_val,
         quant_type=quant_type,
         float_to_int_impl_type=float_to_int_impl_type,
         scaling_impl_type=scaling_impl_type,
         scaling_per_channel=scaling_per_channel,
         scaling_min_val=scaling_min_val,
         per_channel_broadcastable_shape=per_channel_broadcastable_shape,
         min_overall_bit_width=min_overall_bit_width,
         max_overall_bit_width=max_overall_bit_width,
         bit_width_impl_override=bit_width_impl_override,
         bit_width_impl_type=bit_width_impl_type,
         restrict_bit_width_type=restrict_bit_width_type,
         restrict_scaling_type=restrict_scaling_type,
         override_pretrained_bit_width=override_pretrained_bit_width,
         scaling_stats_sigma=scaling_stats_sigma,
         scaling_stats_permute_dims=scaling_stats_permute_dims,
         scaling_stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
         scaling_stats_op=scaling_stats_op,
         scaling_stats_buffer_momentum=scaling_stats_buffer_momentum)
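A corresponding sketch for this QuantReLU constructor, under the same import-path and QuantType.INT assumptions as above:

import torch
from brevitas.nn import QuantReLU          # import path is an assumption
from brevitas.core.quant import QuantType  # import path is an assumption

# 4-bit unsigned quantized ReLU; with the default ScalingImplType.PARAMETER,
# max_val initializes a learned scale parameter
act = QuantReLU(bit_width=4, max_val=6.0, quant_type=QuantType.INT)
y = act(torch.randn(2, 16))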
Example #3
 def __init__(
         self,
         bit_width: int,
         max_val: float,
         quant_type: QuantType = QuantType.FP,
         float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
         scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
         scaling_override: Optional[Module] = None,
         scaling_per_channel: bool = False,
         scaling_min_val: Optional[float] = SCALING_MIN_VAL,
         scaling_stats_sigma: float = 2.0,
         scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
         scaling_stats_buffer_momentum: float = 0.1,
         scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
         per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
         min_overall_bit_width: Optional[int] = 2,
         max_overall_bit_width: Optional[int] = None,
         bit_width_impl_override: Optional[BitWidthParameter] = None,
         bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
         restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
         restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
         override_pretrained_bit_width: bool = False,
         return_quant_tensor: bool = False):
     super(QuantReLU,
           self).__init__(return_quant_tensor=return_quant_tensor)
     # save a copy of the args passed to the constructor, used to determine
     # whether the quantization config is exportable to something FINN supports
     self.init_args = locals()
     activation_impl = nn.ReLU()
     self.act_quant_proxy = ActivationQuantProxy(
         activation_impl=activation_impl,
         bit_width=bit_width,
         signed=False,
         narrow_range=False,
         scaling_override=scaling_override,
         min_val=0.0,
         max_val=max_val,
         quant_type=quant_type,
         float_to_int_impl_type=float_to_int_impl_type,
         scaling_impl_type=scaling_impl_type,
         scaling_per_channel=scaling_per_channel,
         scaling_min_val=scaling_min_val,
         per_channel_broadcastable_shape=per_channel_broadcastable_shape,
         min_overall_bit_width=min_overall_bit_width,
         max_overall_bit_width=max_overall_bit_width,
         bit_width_impl_override=bit_width_impl_override,
         bit_width_impl_type=bit_width_impl_type,
         restrict_bit_width_type=restrict_bit_width_type,
         restrict_scaling_type=restrict_scaling_type,
         override_pretrained_bit_width=override_pretrained_bit_width,
         scaling_stats_sigma=scaling_stats_sigma,
         scaling_stats_permute_dims=scaling_stats_permute_dims,
         scaling_stats_op=scaling_stats_op,
         scaling_stats_buffer_momentum=scaling_stats_buffer_momentum)
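This variant differs from the previous QuantReLU only in capturing locals() into self.init_args. A short sketch of how that record can be inspected after construction (the FINN export check itself is not shown in the snippet):

# init_args holds every constructor argument by name, so an export pass can
# later verify that the chosen quantization config maps onto FINN primitives
act = QuantReLU(bit_width=4, max_val=6.0)
assert act.init_args['bit_width'] == 4
assert act.init_args['max_val'] == 6.0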
Example #4
 def __init__(
         self,
         bit_width: int,
         narrow_range: bool = False,
         quant_type: QuantType = QuantType.FP,
         float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
         min_overall_bit_width: Optional[int] = 2,
         max_overall_bit_width: Optional[int] = None,
         bit_width_impl_override: Optional[BitWidthParameter] = None,
         bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
         restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
         restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
         scaling_min_val: Optional[float] = SCALING_MIN_VAL,
         override_pretrained_bit_width: bool = False,
         return_quant_tensor: bool = False):
     super(QuantTanh,
           self).__init__(return_quant_tensor=return_quant_tensor)
     activation_impl = nn.Tanh()
     self.act_quant_proxy = ActivationQuantProxy(
         activation_impl=activation_impl,
         bit_width=bit_width,
         signed=True,
         narrow_range=narrow_range,
         scaling_override=None,
         min_val=-1.0,
         max_val=1.0,
         quant_type=quant_type,
         float_to_int_impl_type=float_to_int_impl_type,
         scaling_impl_type=ScalingImplType.CONST,
         scaling_per_channel=False,
         scaling_min_val=scaling_min_val,
         per_channel_broadcastable_shape=None,
         min_overall_bit_width=min_overall_bit_width,
         max_overall_bit_width=max_overall_bit_width,
         bit_width_impl_override=bit_width_impl_override,
         bit_width_impl_type=bit_width_impl_type,
         restrict_bit_width_type=restrict_bit_width_type,
         restrict_scaling_type=restrict_scaling_type,
         override_pretrained_bit_width=override_pretrained_bit_width,
         scaling_stats_sigma=None,
         scaling_stats_input_view_shape_impl=None,
         scaling_stats_op=None,
         scaling_stats_buffer_momentum=None,
         scaling_stats_permute_dims=None)
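A final sketch for QuantTanh, again assuming the brevitas.nn import path and the QuantType.INT member. Note that this constructor hard-codes ScalingImplType.CONST scaling over [-1.0, 1.0], so no min_val/max_val or stats arguments are exposed:

import torch
from brevitas.nn import QuantTanh          # import path is an assumption
from brevitas.core.quant import QuantType  # import path is an assumption

# signed 8-bit tanh; narrow_range=True drops the most negative integer level
# so the quantized range is symmetric around zero
act = QuantTanh(bit_width=8, narrow_range=True, quant_type=QuantType.INT)
y = act(torch.randn(10))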