def __init__(self,
             bit_width: int,
             min_val: float = -1.0,
             max_val: float = 1.0,
             narrow_range: bool = False,
             quant_type: QuantType = QuantType.FP,
             float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
             scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
             scaling_override: Optional[Module] = None,
             scaling_per_channel: bool = False,
             scaling_stats_sigma: float = 3.0,
             scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
             scaling_stats_buffer_momentum: float = 0.1,
             scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
             per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
             min_overall_bit_width: Optional[int] = 2,
             max_overall_bit_width: Optional[int] = None,
             bit_width_impl_override: Optional[BitWidthParameter] = None,
             bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
             restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
             restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
             scaling_min_val: Optional[float] = SCALING_MIN_VAL,
             override_pretrained_bit_width: bool = False,
             return_quant_tensor: bool = False):
    super(QuantHardTanh, self).__init__(return_quant_tensor=return_quant_tensor)
    # In FP mode no quantizer bounds the output, so the clamp to
    # [min_val, max_val] has to be applied explicitly; with integer
    # quantization the quantizer itself bounds the range, so a
    # pass-through suffices.
    if quant_type == QuantType.FP:
        activation_impl = ConstScalarClamp(min_val=min_val, max_val=max_val)
    else:
        activation_impl = Identity()
    self.act_quant_proxy = ActivationQuantProxy(
        activation_impl=activation_impl,
        bit_width=bit_width,
        signed=True,
        narrow_range=narrow_range,
        scaling_override=scaling_override,
        min_val=min_val,
        max_val=max_val,
        quant_type=quant_type,
        float_to_int_impl_type=float_to_int_impl_type,
        scaling_impl_type=scaling_impl_type,
        scaling_per_channel=scaling_per_channel,
        scaling_min_val=scaling_min_val,
        per_channel_broadcastable_shape=per_channel_broadcastable_shape,
        min_overall_bit_width=min_overall_bit_width,
        max_overall_bit_width=max_overall_bit_width,
        bit_width_impl_override=bit_width_impl_override,
        bit_width_impl_type=bit_width_impl_type,
        restrict_bit_width_type=restrict_bit_width_type,
        restrict_scaling_type=restrict_scaling_type,
        override_pretrained_bit_width=override_pretrained_bit_width,
        scaling_stats_sigma=scaling_stats_sigma,
        scaling_stats_op=scaling_stats_op,
        scaling_stats_buffer_momentum=scaling_stats_buffer_momentum,
        scaling_stats_permute_dims=scaling_stats_permute_dims)
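A minimal usage sketch for the QuantHardTanh constructor above. The import paths are assumptions based on a typical Brevitas layout and are not confirmed by this source; only the constructor arguments mirror the signature shown.

import torch

# Assumed import locations; adjust to the actual package layout.
from brevitas.nn import QuantHardTanh
from brevitas.core.quant import QuantType

# 4-bit signed quantized hard-tanh clamping activations to [-1.0, 1.0].
act = QuantHardTanh(bit_width=4,
                    min_val=-1.0,
                    max_val=1.0,
                    quant_type=QuantType.INT)
y = act(torch.randn(8, 16))  # same shape as the input, quantized values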
def __init__(self,
             bit_width: int,
             max_val: float,
             quant_type: QuantType = QuantType.FP,
             float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
             scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
             scaling_override: Optional[Module] = None,
             scaling_per_channel: bool = False,
             scaling_min_val: Optional[float] = SCALING_MIN_VAL,
             scaling_stats_sigma: float = 2.0,
             scaling_stats_input_view_shape_impl: Optional[StatsInputViewShapeImpl] = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS,
             scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
             scaling_stats_buffer_momentum: float = 0.1,
             scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
             per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
             min_overall_bit_width: Optional[int] = 2,
             max_overall_bit_width: Optional[int] = None,
             bit_width_impl_override: Optional[BitWidthParameter] = None,
             bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
             restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
             restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
             override_pretrained_bit_width: bool = False,
             return_quant_tensor: bool = False):
    super(QuantReLU, self).__init__(return_quant_tensor=return_quant_tensor)
    activation_impl = nn.ReLU()
    # ReLU output is non-negative, so the quantizer is unsigned over [0, max_val].
    self.act_quant_proxy = ActivationQuantProxy(
        activation_impl=activation_impl,
        bit_width=bit_width,
        signed=False,
        narrow_range=False,
        scaling_override=scaling_override,
        min_val=0.0,
        max_val=max_val,
        quant_type=quant_type,
        float_to_int_impl_type=float_to_int_impl_type,
        scaling_impl_type=scaling_impl_type,
        scaling_per_channel=scaling_per_channel,
        scaling_min_val=scaling_min_val,
        per_channel_broadcastable_shape=per_channel_broadcastable_shape,
        min_overall_bit_width=min_overall_bit_width,
        max_overall_bit_width=max_overall_bit_width,
        bit_width_impl_override=bit_width_impl_override,
        bit_width_impl_type=bit_width_impl_type,
        restrict_bit_width_type=restrict_bit_width_type,
        restrict_scaling_type=restrict_scaling_type,
        override_pretrained_bit_width=override_pretrained_bit_width,
        scaling_stats_sigma=scaling_stats_sigma,
        scaling_stats_permute_dims=scaling_stats_permute_dims,
        scaling_stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
        scaling_stats_op=scaling_stats_op,
        scaling_stats_buffer_momentum=scaling_stats_buffer_momentum)
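An analogous sketch for QuantReLU. Note that max_val is required here: it seeds the learned (PARAMETER) scale of the unsigned quantizer. Import paths are again assumed, not confirmed by this source.

import torch

# Assumed import locations; adjust to the actual package layout.
from brevitas.nn import QuantReLU
from brevitas.core.quant import QuantType

# Unsigned 8-bit ReLU; max_val initializes the learned activation scale.
act = QuantReLU(bit_width=8, max_val=6.0, quant_type=QuantType.INT)
y = act(torch.randn(2, 64, 7, 7))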
def __init__(self,
             bit_width: int,
             max_val: float,
             quant_type: QuantType = QuantType.FP,
             float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
             scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
             scaling_override: Optional[Module] = None,
             scaling_per_channel: bool = False,
             scaling_min_val: Optional[float] = SCALING_MIN_VAL,
             scaling_stats_sigma: float = 2.0,
             scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
             scaling_stats_buffer_momentum: float = 0.1,
             scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
             per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
             min_overall_bit_width: Optional[int] = 2,
             max_overall_bit_width: Optional[int] = None,
             bit_width_impl_override: Optional[BitWidthParameter] = None,
             bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
             restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
             restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
             override_pretrained_bit_width: bool = False,
             return_quant_tensor: bool = False):
    super(QuantReLU, self).__init__(return_quant_tensor=return_quant_tensor)
    # Save a copy of the args passed to the constructor; used later to determine
    # whether the quantization config is exportable to something FINN supports.
    self.init_args = locals()
    activation_impl = nn.ReLU()
    self.act_quant_proxy = ActivationQuantProxy(
        activation_impl=activation_impl,
        bit_width=bit_width,
        signed=False,
        narrow_range=False,
        scaling_override=scaling_override,
        min_val=0.0,
        max_val=max_val,
        quant_type=quant_type,
        float_to_int_impl_type=float_to_int_impl_type,
        scaling_impl_type=scaling_impl_type,
        scaling_per_channel=scaling_per_channel,
        scaling_min_val=scaling_min_val,
        per_channel_broadcastable_shape=per_channel_broadcastable_shape,
        min_overall_bit_width=min_overall_bit_width,
        max_overall_bit_width=max_overall_bit_width,
        bit_width_impl_override=bit_width_impl_override,
        bit_width_impl_type=bit_width_impl_type,
        restrict_bit_width_type=restrict_bit_width_type,
        restrict_scaling_type=restrict_scaling_type,
        override_pretrained_bit_width=override_pretrained_bit_width,
        scaling_stats_sigma=scaling_stats_sigma,
        scaling_stats_permute_dims=scaling_stats_permute_dims,
        scaling_stats_op=scaling_stats_op,
        scaling_stats_buffer_momentum=scaling_stats_buffer_momentum)
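The only substantive difference from the previous QuantReLU variant is the locals() snapshot, which an export path can later inspect. A sketch of how such a check might look; is_finn_exportable is a hypothetical helper, not a library function, and real export code would check more fields than this.

def is_finn_exportable(act) -> bool:
    # Hypothetical check: FINN-style export expects integer quantization,
    # so a layer still configured as QuantType.FP cannot be exported.
    # act.init_args is the locals() dict captured in __init__ above.
    return act.init_args.get('quant_type') != QuantType.FP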
def __init__(self,
             bit_width: int,
             narrow_range: bool = False,
             quant_type: QuantType = QuantType.FP,
             float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
             min_overall_bit_width: Optional[int] = 2,
             max_overall_bit_width: Optional[int] = None,
             bit_width_impl_override: Optional[BitWidthParameter] = None,
             bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
             restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
             restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
             scaling_min_val: Optional[float] = SCALING_MIN_VAL,
             override_pretrained_bit_width: bool = False,
             return_quant_tensor: bool = False):
    super(QuantTanh, self).__init__(return_quant_tensor=return_quant_tensor)
    activation_impl = nn.Tanh()
    # Tanh is bounded in [-1, 1], so the scale is a fixed constant and the
    # stats-based scaling machinery is disabled.
    self.act_quant_proxy = ActivationQuantProxy(
        activation_impl=activation_impl,
        bit_width=bit_width,
        signed=True,
        narrow_range=narrow_range,
        scaling_override=None,
        min_val=-1.0,
        max_val=1.0,
        quant_type=quant_type,
        float_to_int_impl_type=float_to_int_impl_type,
        scaling_impl_type=ScalingImplType.CONST,
        scaling_per_channel=False,
        scaling_min_val=scaling_min_val,
        per_channel_broadcastable_shape=None,
        min_overall_bit_width=min_overall_bit_width,
        max_overall_bit_width=max_overall_bit_width,
        bit_width_impl_override=bit_width_impl_override,
        bit_width_impl_type=bit_width_impl_type,
        restrict_bit_width_type=restrict_bit_width_type,
        restrict_scaling_type=restrict_scaling_type,
        override_pretrained_bit_width=override_pretrained_bit_width,
        scaling_stats_sigma=None,
        scaling_stats_input_view_shape_impl=None,
        scaling_stats_op=None,
        scaling_stats_buffer_momentum=None,
        scaling_stats_permute_dims=None)
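A usage sketch for QuantTanh; since the range is hard-wired to [-1, 1] with a constant scale, only the bit width and quantization type normally need to be chosen. Import paths are assumed, as in the earlier sketches.

import torch

# Assumed import locations; adjust to the actual package layout.
from brevitas.nn import QuantTanh
from brevitas.core.quant import QuantType

# Signed 8-bit tanh with a symmetric (narrow) integer range.
act = QuantTanh(bit_width=8, quant_type=QuantType.INT, narrow_range=True)
y = act(torch.randn(4, 32))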