Example #1
 def __init__(self,
              collect_stats_steps: int,
              scaling_stats_impl: Module,
              scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
              scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
              restrict_scaling_impl: Optional[Module] = None,
              scaling_stats_permute_dims: Optional[Tuple[int, ...]] = None,
              scaling_stats_momentum: float = DEFAULT_MOMENTUM,
              scaling_min_val: Optional[float] = None) -> None:
     super(ParameterFromRuntimeStatsScaling, self).__init__()
     assert collect_stats_steps > 0, 'Steps should be more than 0'
     if config.JIT_ENABLED:
         warnings.warn(
             'BREVITAS_JIT=1 on ParameterFromRuntimeStatsScaling could result in numerical '
             'errors. Disabling it is highly recommended unless you are resuming from a previous '
             'quantized checkpoint (not a floating-point one).')
     if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None:
         raise RuntimeError(
             "Per channel runtime stats require a permute shape")
     self.collect_stats_steps = collect_stats_steps
     self.counter: int = brevitas.jit.Attribute(0, int)
     self.stats_permute_dims = scaling_stats_permute_dims
     self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
     self.stats = _Stats(scaling_stats_impl, scaling_shape)
     self.momentum = scaling_stats_momentum
     self.value = Parameter(torch.full(scaling_shape, 1.0))
     self.restrict_clamp_scaling = _RestrictClampValue(
         scaling_min_val, restrict_scaling_impl)
     if restrict_scaling_impl is not None:
         self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
     else:
         self.restrict_inplace_preprocess = InplaceNoOp()
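Usage note (not part of the original example): a minimal sketch of how this constructor might be instantiated for a per-tensor scale. The import paths and the AbsMax statistics implementation are assumptions and may differ between brevitas versions.

 # Hypothetical usage sketch; import paths are assumptions and may vary across brevitas versions.
 from brevitas.core.scaling import ParameterFromRuntimeStatsScaling
 from brevitas.core.stats import AbsMax

 # Per-tensor scale: the defaults (SCALAR_SHAPE, OverBatchOverTensorView) are kept,
 # so only the number of stats-collection steps and the stats implementation are passed.
 scaling_impl = ParameterFromRuntimeStatsScaling(
     collect_stats_steps=300,
     scaling_stats_impl=AbsMax())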
Example #2
 def __init__(self,
              collect_stats_steps: int,
              scaling_stats_impl: Module,
              scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
              scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
              restrict_scaling_impl: Optional[Module] = None,
              scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
              scaling_min_val: Optional[float] = None) -> None:
     super(ParameterFromRuntimeStatsScaling, self).__init__()
     assert collect_stats_steps > 0, 'Steps should be more than 0'
     self.collect_stats_steps = collect_stats_steps
     self.counter: int = brevitas.jit.Attribute(0, int)
     self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
     self.stats = _Stats(scaling_stats_impl, scaling_shape)
     self.momentum = scaling_stats_momentum
     self.register_buffer('buffer', torch.full(scaling_shape, 1.0))
     self.value = Parameter(torch.full(scaling_shape, 1.0))
     self.restrict_clamp_scaling = _RestrictClampValue(
         scaling_min_val, restrict_scaling_impl)
     if restrict_scaling_impl is not None:
         self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
         self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
     else:
         self.restrict_inplace_preprocess = Identity()
         self.restrict_preprocess = Identity()
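Usage note (not part of the original example): unlike the first variant, this one also registers a persistent 'buffer' and keeps an out-of-place restrict_preprocess next to the in-place one. Below is a sketch of constructing it with a power-of-two scale restriction; PowerOfTwoRestrictValue and the import paths are assumptions about the brevitas API.

 # Hypothetical sketch: passing a restriction makes restrict_preprocess and
 # restrict_inplace_preprocess come from restrict_scaling_impl instead of Identity.
 # Import paths are assumptions.
 from brevitas.core.restrict_val import PowerOfTwoRestrictValue
 from brevitas.core.scaling import ParameterFromRuntimeStatsScaling
 from brevitas.core.stats import AbsMax

 scaling_impl = ParameterFromRuntimeStatsScaling(
     collect_stats_steps=100,
     scaling_stats_impl=AbsMax(),
     restrict_scaling_impl=PowerOfTwoRestrictValue(),  # constrain the learned scale to powers of two
     scaling_min_val=1e-8)                             # lower clamp applied through _RestrictClampValue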
Example #3
 def __init__(self,
              collect_stats_steps: int,
              restrict_scaling_impl: Module,
              scaling_stats_impl: Module,
              scaling_shape: Tuple[int, ...],
              scaling_stats_input_view_shape_impl: Module,
              scaling_stats_permute_dims: Optional[Tuple[int, ...]] = None,
              scaling_stats_momentum: float = DEFAULT_MOMENTUM,
              scaling_min_val: Optional[float] = None) -> None:
     super(ParameterFromRuntimeStatsScaling, self).__init__()
     assert collect_stats_steps > 0, 'Steps should be more than 0'
     if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None:
         raise RuntimeError(
             "Per channel runtime stats require a permute shape")
     self.collect_stats_steps = collect_stats_steps
     self.counter: int = brevitas.jit.Attribute(0, int)
     self.stats_permute_dims = scaling_stats_permute_dims
     self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
     self.stats = _Stats(scaling_stats_impl, scaling_shape)
     self.momentum = scaling_stats_momentum
     self.value = Parameter(torch.full(scaling_shape, 1.0))
     self.restrict_clamp_scaling = _RestrictClampValue(
         scaling_min_val, restrict_scaling_impl)
     self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
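Usage note (not part of the original example): in this variant restrict_scaling_impl, scaling_shape, and the view module are required, and a non-scalar scaling_shape without scaling_stats_permute_dims raises the RuntimeError above. A per-channel sketch for NCHW activations follows; the ChannelFirstView helper, the stats-impl arguments, and the import paths are assumptions.

 # Hypothetical per-channel sketch; helper module and import paths are assumptions.
 import torch
 from brevitas.core.restrict_val import FloatRestrictValue
 from brevitas.core.scaling import ParameterFromRuntimeStatsScaling
 from brevitas.core.stats import AbsMax

 class ChannelFirstView(torch.nn.Module):
     # Hypothetical helper: flatten an already channel-first tensor to (channels, -1).
     def forward(self, x):
         return x.reshape(x.shape[0], -1)

 channels = 64
 scaling_impl = ParameterFromRuntimeStatsScaling(
     collect_stats_steps=30,
     restrict_scaling_impl=FloatRestrictValue(),       # leave the scale unconstrained
     scaling_stats_impl=AbsMax(stats_reduce_dim=1),    # abs-max over the flattened non-channel dims
     scaling_shape=(1, channels, 1, 1),                # one scale per channel, broadcastable over NCHW
     scaling_stats_input_view_shape_impl=ChannelFirstView(),
     scaling_stats_permute_dims=(1, 0, 2, 3))          # move channels first before viewing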