def __init__(
        self,
        collect_stats_steps: int,
        scaling_stats_impl: Module,
        scaling_stats_input_view_shape_impl: Optional[Module] = None,
        scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
        restrict_scaling_impl: Optional[Module] = None,
        scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
        scaling_min_val: Optional[float] = None) -> None:
    """Scaling that is first collected from runtime statistics, then learned.

    Args:
        collect_stats_steps: number of forward steps over which statistics
            are collected before switching to the learned parameter. Must be > 0.
        scaling_stats_impl: module computing the statistic used as scale.
        scaling_stats_input_view_shape_impl: module reshaping the input before
            statistics are computed. Defaults to ``OverBatchOverTensorView()``.
        scaling_shape: shape of the scale buffer/parameter.
        restrict_scaling_impl: optional module restricting the scale domain
            (e.g. to log-fp or power-of-two values).
        scaling_stats_momentum: momentum for the running statistics; ``None``
            presumably selects cumulative averaging — confirm in forward.
        scaling_min_val: optional lower clamp applied to the scale.
    """
    super(ParameterFromRuntimeStatsScaling, self).__init__()
    assert collect_stats_steps > 0, 'Steps should be more than 0'
    # Fix for a mutable default argument: a Module instance used as a default
    # is created once at function-definition time and shared across every
    # instance of this class. Build a fresh view module per instance instead.
    if scaling_stats_input_view_shape_impl is None:
        scaling_stats_input_view_shape_impl = OverBatchOverTensorView()
    self.collect_stats_steps = collect_stats_steps
    # JIT-visible mutable int counting how many stats-collection steps ran.
    self.counter: int = brevitas.jit.Attribute(0, int)
    self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
    self.stats = _Stats(scaling_stats_impl, scaling_shape)
    self.momentum = scaling_stats_momentum
    # Running-stats storage: a buffer (tracked in state_dict, not trained) ...
    self.register_buffer('buffer', torch.full(scaling_shape, 1.0))
    # ... and the learned scale parameter, presumably re-initialized from the
    # collected stats once collection ends — confirm in forward/state loading.
    self.value = Parameter(torch.full(scaling_shape, 1.0))
    self.restrict_clamp_scaling = _RestrictClampValue(
        scaling_min_val, restrict_scaling_impl)
    if restrict_scaling_impl is not None:
        self.restrict_inplace_preprocess = \
            restrict_scaling_impl.restrict_init_inplace_module()
        self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
    else:
        self.restrict_inplace_preprocess = Identity()
        self.restrict_preprocess = Identity()
def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]):
    """Compose an optional min-value clamp with an optional restrict transform.

    Args:
        scaling_min_val: lower bound for the value; ``None`` or ``0`` disables
            clamping entirely.
        restrict_value_impl: optional module restricting the value domain; when
            given, the min value is first mapped into that restricted domain.
    """
    super(_RestrictClampValue, self).__init__()
    if scaling_min_val is None or scaling_min_val == 0:
        # No meaningful lower bound: clamping becomes a pass-through.
        self.clamp_min_ste = Identity()
    else:
        if restrict_value_impl is not None:
            # Express the bound in the restricted domain before clamping.
            scaling_min_val = restrict_value_impl.restrict_init_float(scaling_min_val)
        self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
    self.restrict_value_impl = Identity() if restrict_value_impl is None else restrict_value_impl
def __init__(
        self,
        stats_op: StatsOp,
        restrict_scaling_type: RestrictValueType,
        stats_output_shape: Tuple[int, ...],
        scaling_min_val: Optional[float],
        affine: bool) -> None:
    """Scaling derived from statistics, with optional affine rescaling.

    Raises:
        Exception: if the restriction type is unsupported for stats scaling,
            or if MAX_AVE stats are combined with a per-channel output shape.
    """
    super(StatsScaling, self).__init__()
    supported_restrictions = (
        RestrictValueType.FP,
        RestrictValueType.LOG_FP,
        RestrictValueType.POWER_OF_TWO)
    if restrict_scaling_type not in supported_restrictions:
        raise Exception(
            "Restriction of type {} is not supported for stats scaling.".
            format(str(restrict_scaling_type)))
    if stats_op == StatsOp.MAX_AVE and stats_output_shape != SCALING_SCALAR_SHAPE:
        raise Exception(
            "Scaling with MAX_AVE stats can't be over output channels.")
    self.affine_rescaling = AffineRescaling(stats_output_shape) if affine else Identity()
    self.restrict_scaling = RestrictValue(
        restrict_scaling_type, FloatToIntImplType.CEIL, scaling_min_val)
    self.restrict_scaling_preprocess = RestrictValue.restrict_value_op(
        restrict_scaling_type,
        restrict_value_op_impl_type=RestrictValueOpImplType.TORCH_MODULE)
def __init__(
        self,
        bit_width: int,
        min_val: float = -1.0,
        max_val: float = 1.0,
        narrow_range: bool = False,
        quant_type: QuantType = QuantType.FP,
        float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
        scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
        scaling_override: Optional[Module] = None,
        scaling_per_channel: bool = False,
        scaling_stats_sigma: float = 3.0,
        scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
        scaling_stats_buffer_momentum: float = 0.1,
        scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
        per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
        min_overall_bit_width: Optional[int] = 2,
        max_overall_bit_width: Optional[int] = None,
        # Annotation fixed: was Union[BitWidthParameter], a one-arm Union;
        # the default is None, so Optional is the correct type.
        bit_width_impl_override: Optional[BitWidthParameter] = None,
        bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
        restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
        restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
        scaling_min_val: Optional[float] = SCALING_MIN_VAL,
        override_pretrained_bit_width: bool = False,
        return_quant_tensor: bool = False):
    """Quantized HardTanh-style activation.

    Builds a signed activation quantization proxy over [min_val, max_val].
    All scaling/bit-width keyword arguments are forwarded unchanged to
    ``ActivationQuantProxy``.
    """
    super(QuantHardTanh, self).__init__(return_quant_tensor=return_quant_tensor)
    if quant_type == QuantType.FP:
        # Unquantized path: clamp to [min_val, max_val] explicitly, since no
        # quantizer is applied. Presumably the quant path clamps implicitly
        # via the proxy's min_val/max_val — confirm in ActivationQuantProxy.
        activation_impl = ConstScalarClamp(min_val=min_val, max_val=max_val)
    else:
        activation_impl = Identity()
    self.act_quant_proxy = ActivationQuantProxy(
        activation_impl=activation_impl,
        bit_width=bit_width,
        signed=True,  # HardTanh range is symmetric around zero
        narrow_range=narrow_range,
        scaling_override=scaling_override,
        min_val=min_val,
        max_val=max_val,
        quant_type=quant_type,
        float_to_int_impl_type=float_to_int_impl_type,
        scaling_impl_type=scaling_impl_type,
        scaling_per_channel=scaling_per_channel,
        scaling_min_val=scaling_min_val,
        per_channel_broadcastable_shape=per_channel_broadcastable_shape,
        min_overall_bit_width=min_overall_bit_width,
        max_overall_bit_width=max_overall_bit_width,
        bit_width_impl_override=bit_width_impl_override,
        bit_width_impl_type=bit_width_impl_type,
        restrict_bit_width_type=restrict_bit_width_type,
        restrict_scaling_type=restrict_scaling_type,
        override_pretrained_bit_width=override_pretrained_bit_width,
        scaling_stats_sigma=scaling_stats_sigma,
        scaling_stats_op=scaling_stats_op,
        scaling_stats_buffer_momentum=scaling_stats_buffer_momentum,
        scaling_stats_permute_dims=scaling_stats_permute_dims)
def __init__(
        self,
        restrict_scaling_impl: Module,
        scaling_shape: Tuple[int, ...],
        scaling_min_val: Optional[float] = None,
        affine_rescaling: bool = False) -> None:
    """Stats-driven scaling: restrict/clamp the scale, optionally followed by
    a learned affine rescaling stage.

    Args:
        restrict_scaling_impl: module restricting the scale's value domain.
        scaling_shape: shape of the scale (and of the affine parameters).
        scaling_min_val: optional lower clamp on the scale.
        affine_rescaling: whether to add a learned affine stage.
    """
    super(_StatsScaling, self).__init__()
    self.affine_rescaling = (
        _AffineRescaling(scaling_shape) if affine_rescaling else Identity())
    self.restrict_clamp_scaling = _RestrictClampValue(
        scaling_min_val, restrict_scaling_impl)
    # Preprocessing step for initialization values in the restricted domain.
    self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
def __init__(self, permute_dims: Optional[Tuple[int, ...]]) -> None:
    """View over output channels, with an optional dimension permutation
    applied first; ``None`` means no permutation (pass-through)."""
    super(OverOutputChannelView, self).__init__()
    self.permute_impl = (
        Identity() if permute_dims is None else PermuteDims(permute_dims))
def restrict_init_inplace_module(self):
    """Return the in-place init-preprocessing module: no transformation is
    needed for this restriction, so a pass-through ``Identity`` is returned."""
    return Identity()