Example No. 1
 def __init__(self,
              collect_stats_steps: int,
              scaling_stats_impl: Module,
              scaling_stats_input_view_shape_impl:
              Module = OverBatchOverTensorView(),
              scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
              restrict_scaling_impl: Optional[Module] = None,
              scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
              scaling_min_val: Optional[float] = None) -> None:
     super(ParameterFromRuntimeStatsScaling, self).__init__()
     assert collect_stats_steps > 0, 'Steps should be more than 0'
     self.collect_stats_steps = collect_stats_steps
     self.counter: int = brevitas.jit.Attribute(0, int)
     self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
     self.stats = _Stats(scaling_stats_impl, scaling_shape)
     self.momentum = scaling_stats_momentum
     self.register_buffer('buffer', torch.full(scaling_shape, 1.0))
     self.value = Parameter(torch.full(scaling_shape, 1.0))
     self.restrict_clamp_scaling = _RestrictClampValue(
         scaling_min_val, restrict_scaling_impl)
     if restrict_scaling_impl is not None:
         self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
         self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
     else:
         # No restriction configured: preprocessing falls back to a no-op.
         self.restrict_inplace_preprocess = Identity()
         self.restrict_preprocess = Identity()
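
These snippets all share the same construction-time fallback: when an optional implementation module is not supplied, a no-op Identity module is stored in its place so the forward pass can call it unconditionally. A minimal, self-contained sketch of that pattern in plain PyTorch (the SimpleScaling class and its preprocess argument are hypothetical, not Brevitas API):

from typing import Optional

import torch
from torch import nn


class SimpleScaling(nn.Module):
    """Hypothetical module illustrating the Identity-fallback pattern above."""

    def __init__(self, preprocess: Optional[nn.Module] = None) -> None:
        super().__init__()
        # Decide once, at construction time: either the configured module
        # or a no-op, so forward() never has to branch on None.
        self.preprocess = preprocess if preprocess is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.preprocess(x)


scaling = SimpleScaling()                 # no preprocess given -> Identity
print(scaling(torch.tensor([1.0, 2.0])))  # tensor([1., 2.]), unchanged
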
Example No. 2
 def __init__(self, scaling_min_val: Optional[float],
              restrict_value_impl: Optional[Module]):
     super(_RestrictClampValue, self).__init__()
     if scaling_min_val is not None and scaling_min_val != 0:
         if restrict_value_impl is not None:
             # Express the configured floor in the restricted domain before
             # it is used as a clamp threshold.
             scaling_min_val = restrict_value_impl.restrict_init_float(scaling_min_val)
         self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
     else:
         # No floor configured: clamping falls back to a no-op.
         self.clamp_min_ste = Identity()
     if restrict_value_impl is not None:
         self.restrict_value_impl = restrict_value_impl
     else:
         self.restrict_value_impl = Identity()
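
Example No. 2 also hints at why the floor is preprocessed at all: if the scale is learned in a restricted domain (for instance, a log domain), the configured minimum has to be mapped into that same domain before it can serve as a clamp threshold. A sketch of that idea under that assumption (the LogRestrict class below is hypothetical, not Brevitas API; only the restrict_init_float name is taken from the snippet):

import math

import torch
from torch import nn


class LogRestrict(nn.Module):
    """Hypothetical restriction: the scale is stored as log(scale) and
    exponentiated on the way out."""

    def restrict_init_float(self, x: float) -> float:
        # Map a plain float threshold into the log domain.
        return math.log(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(x)


restrict = LogRestrict()
min_val = restrict.restrict_init_float(1e-8)  # threshold, now in log domain
log_scale = torch.tensor([-30.0, -5.0])       # learned log-domain values
scale = restrict(torch.clamp_min(log_scale, min_val))
print(scale)  # both entries are >= 1e-8, the configured floor
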
Example No. 3
    def __init__(self, stats_op: StatsOp,
                 restrict_scaling_type: RestrictValueType,
                 stats_output_shape: Tuple[int, ...],
                 scaling_min_val: Optional[float], affine: bool) -> None:
        super(StatsScaling, self).__init__()

        if not (restrict_scaling_type == RestrictValueType.FP
                or restrict_scaling_type == RestrictValueType.LOG_FP
                or restrict_scaling_type == RestrictValueType.POWER_OF_TWO):
            raise Exception(
                "Restriction of type {} is not supported for stats scaling.".format(
                    str(restrict_scaling_type)))
        if stats_op == StatsOp.MAX_AVE and stats_output_shape != SCALING_SCALAR_SHAPE:
            raise Exception(
                "Scaling with MAX_AVE stats can't be over output channels.")

        if affine:
            self.affine_rescaling = AffineRescaling(stats_output_shape)
        else:
            self.affine_rescaling = Identity()

        self.restrict_scaling = RestrictValue(restrict_scaling_type,
                                              FloatToIntImplType.CEIL,
                                              scaling_min_val)
        self.restrict_scaling_preprocess = RestrictValue.restrict_value_op(
            restrict_scaling_type,
            restrict_value_op_impl_type=RestrictValueOpImplType.TORCH_MODULE)
Example No. 4
 def __init__(
         self,
         bit_width: int,
         min_val: float = -1.0,
         max_val: float = 1.0,
         narrow_range: bool = False,
         quant_type: QuantType = QuantType.FP,
         float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
         scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
         scaling_override: Optional[Module] = None,
         scaling_per_channel: bool = False,
         scaling_stats_sigma: float = 3.0,
         scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
         scaling_stats_buffer_momentum: float = 0.1,
         scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
         per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
         min_overall_bit_width: Optional[int] = 2,
         max_overall_bit_width: Optional[int] = None,
         bit_width_impl_override: Optional[BitWidthParameter] = None,
         bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
         restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
         restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
         scaling_min_val: Optional[float] = SCALING_MIN_VAL,
         override_pretrained_bit_width: bool = False,
         return_quant_tensor: bool = False):
     super(QuantHardTanh, self).__init__(return_quant_tensor=return_quant_tensor)
     if quant_type == QuantType.FP:
         activation_impl = ConstScalarClamp(min_val=min_val,
                                            max_val=max_val)
     else:
         activation_impl = Identity()
     self.act_quant_proxy = ActivationQuantProxy(
         activation_impl=activation_impl,
         bit_width=bit_width,
         signed=True,
         narrow_range=narrow_range,
         scaling_override=scaling_override,
         min_val=min_val,
         max_val=max_val,
         quant_type=quant_type,
         float_to_int_impl_type=float_to_int_impl_type,
         scaling_impl_type=scaling_impl_type,
         scaling_per_channel=scaling_per_channel,
         scaling_min_val=scaling_min_val,
         per_channel_broadcastable_shape=per_channel_broadcastable_shape,
         min_overall_bit_width=min_overall_bit_width,
         max_overall_bit_width=max_overall_bit_width,
         bit_width_impl_override=bit_width_impl_override,
         bit_width_impl_type=bit_width_impl_type,
         restrict_bit_width_type=restrict_bit_width_type,
         restrict_scaling_type=restrict_scaling_type,
         override_pretrained_bit_width=override_pretrained_bit_width,
         scaling_stats_sigma=scaling_stats_sigma,
         scaling_stats_op=scaling_stats_op,
         scaling_stats_buffer_momentum=scaling_stats_buffer_momentum,
         scaling_stats_permute_dims=scaling_stats_permute_dims)
Example No. 5
    def __init__(
            self,
            restrict_scaling_impl: Module,
            scaling_shape: Tuple[int, ...],
            scaling_min_val: Optional[float] = None,
            affine_rescaling: bool = False) -> None:
        super(_StatsScaling, self).__init__()

        if affine_rescaling:
            self.affine_rescaling = _AffineRescaling(scaling_shape)
        else:
            self.affine_rescaling = Identity()
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
Example No. 6
 def __init__(self, permute_dims: Optional[Tuple[int, ...]]) -> None:
     super(OverOutputChannelView, self).__init__()
     if permute_dims is not None:
         self.permute_impl = PermuteDims(permute_dims)
     else:
         self.permute_impl = Identity()
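
Example No. 6 applies the same fallback to an optional dimension permutation. A hypothetical plain-PyTorch analogue is shown below; the forward pass, which flattens everything except the leading (output-channel) dimension, is an assumption about what an over-output-channel view does and is not taken from the snippet:

from typing import Optional, Tuple

import torch
from torch import nn


class PermuteDimsSketch(nn.Module):
    """Hypothetical stand-in for PermuteDims: reorder tensor dimensions."""

    def __init__(self, dims: Tuple[int, ...]) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.permute(*self.dims).contiguous()


class OverOutputChannelViewSketch(nn.Module):
    """Hypothetical analogue: optionally permute, then view the tensor as
    (output_channels, everything_else)."""

    def __init__(self, permute_dims: Optional[Tuple[int, ...]] = None) -> None:
        super().__init__()
        if permute_dims is not None:
            self.permute_impl = PermuteDimsSketch(permute_dims)
        else:
            self.permute_impl = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.permute_impl(x)
        return y.reshape(y.shape[0], -1)


view = OverOutputChannelViewSketch(permute_dims=(1, 0, 2, 3))
print(view(torch.randn(8, 16, 3, 3)).shape)  # torch.Size([16, 72])
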
Example No. 7
 def restrict_init_inplace_module(self):
     return Identity()
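
Example No. 7 is the degenerate case: a restriction type that needs no in-place preprocessing simply hands back an Identity module. For reference, an Identity module is a strict pass-through (demonstrated here with torch.nn.Identity; whether Brevitas uses its own Identity wrapper or the PyTorch one is not visible in these snippets):

import torch
from torch import nn

identity = nn.Identity()
x = torch.randn(4)
assert torch.equal(identity(x), x)  # the input comes back unchanged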