Example #1
    def __init__(self,
                 collect_stats_steps: int,
                 scaling_stats_impl: Module,
                 scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
                 scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
                 restrict_scaling_impl: Optional[Module] = None,
                 scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM,
                 scaling_min_val: Optional[float] = None) -> None:
        super(ParameterFromRuntimeStatsScaling, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        self.collect_stats_steps = collect_stats_steps
        # Number of forward passes seen so far during stats collection.
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
        self.stats = _Stats(scaling_stats_impl, scaling_shape)
        self.momentum = scaling_stats_momentum
        # Running statistics accumulated during the first collect_stats_steps steps.
        self.register_buffer('buffer', torch.full(scaling_shape, 1.0))
        # Learned scale, initialized from the collected statistics afterwards.
        self.value = Parameter(torch.full(scaling_shape, 1.0))
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        if restrict_scaling_impl is not None:
            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
            self.restrict_preprocess = restrict_scaling_impl.restrict_init_module()
        else:
            self.restrict_inplace_preprocess = Identity()
            self.restrict_preprocess = Identity()
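A minimal usage sketch for this variant (hedged: the AbsMax import path and the convention that forward() receives the tensor to collect statistics from are assumptions, not shown above):

    import torch
    from brevitas.core.stats import AbsMax  # assumed import path

    # Collect abs-max statistics over the first 10 forward passes; afterwards
    # the module switches to the learned Parameter initialized from them.
    scaling = ParameterFromRuntimeStatsScaling(
        collect_stats_steps=10,
        scaling_stats_impl=AbsMax())
    scale = scaling(torch.randn(4, 16))  # assumed forward() contract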
Example #2
    def __init__(self,
                 scaling_init: Union[float, Tensor],
                 scaling_shape: Optional[Tuple[int, ...]] = None,
                 restrict_scaling_impl: Optional[Module] = None,
                 scaling_min_val: Optional[float] = None) -> None:
        super(ParameterScaling, self).__init__()

        if (isinstance(scaling_init, Tensor) and scaling_shape is not None
                and scaling_init.shape != SCALAR_SHAPE
                and scaling_init.shape != scaling_shape):
            raise RuntimeError(
                "scaling_init.shape is non-scalar and does not match scaling_shape.")

        if isinstance(scaling_init, Tensor):
            scaling_init = scaling_init.detach()
        else:
            scaling_init = torch.tensor(scaling_init)
        if restrict_scaling_impl is not None:
            scaling_init = restrict_scaling_impl.restrict_init_tensor(
                scaling_init)
        if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
            scaling_init = torch.full(scaling_shape, scaling_init)
        self.value = Parameter(scaling_init)
        self.restrict_clamp_scaling = _RestrictClampValue(
            scaling_min_val, restrict_scaling_impl)
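A sketch of the scalar-init branch above, where a float init is broadcast to scaling_shape via torch.full (values are illustrative):

    # One learnable scale per output channel, all initialized to 0.1.
    per_channel = ParameterScaling(scaling_init=0.1, scaling_shape=(8, 1, 1, 1))
    print(per_channel.value.shape)  # torch.Size([8, 1, 1, 1])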
Example #3
    def __init__(self,
                 collect_stats_steps: int,
                 scaling_stats_impl: Module,
                 scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(),
                 scaling_shape: Tuple[int, ...] = SCALAR_SHAPE,
                 restrict_scaling_impl: Optional[Module] = None,
                 scaling_stats_permute_dims: Optional[Tuple[int, ...]] = None,
                 scaling_stats_momentum: float = DEFAULT_MOMENTUM,
                 scaling_min_val: Optional[float] = None) -> None:
        super(ParameterFromRuntimeStatsScaling, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        if config.JIT_ENABLED:
            warnings.warn(
                'BREVITAS_JIT=1 on ParameterFromRuntimeStatsScaling could result in numerical '
                'errors. Disabling it is highly recommended unless you are resuming from a '
                'previous quantized checkpoint (not a floating-point one).')
        if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None:
            raise RuntimeError("Per channel runtime stats require a permute shape")
        self.collect_stats_steps = collect_stats_steps
        # Number of forward passes seen so far during stats collection.
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_permute_dims = scaling_stats_permute_dims
        self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
        self.stats = _Stats(scaling_stats_impl, scaling_shape)
        self.momentum = scaling_stats_momentum
        self.value = Parameter(torch.full(scaling_shape, 1.0))
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        if restrict_scaling_impl is not None:
            self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
        else:
            self.restrict_inplace_preprocess = InplaceNoOp()
Example #4
    def __init__(self,
                 collect_stats_steps: int,
                 restrict_scaling_impl: Module,
                 scaling_stats_impl: Module,
                 scaling_shape: Tuple[int, ...],
                 scaling_stats_input_view_shape_impl: Module,
                 scaling_stats_permute_dims: Optional[Tuple[int, ...]] = None,
                 scaling_stats_momentum: float = DEFAULT_MOMENTUM,
                 scaling_min_val: Optional[float] = None) -> None:
        super(ParameterFromRuntimeStatsScaling, self).__init__()
        assert collect_stats_steps > 0, 'Steps should be more than 0'
        if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None:
            raise RuntimeError("Per channel runtime stats require a permute shape")
        self.collect_stats_steps = collect_stats_steps
        self.counter: int = brevitas.jit.Attribute(0, int)
        self.stats_permute_dims = scaling_stats_permute_dims
        self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
        self.stats = _Stats(scaling_stats_impl, scaling_shape)
        self.momentum = scaling_stats_momentum
        self.value = Parameter(torch.full(scaling_shape, 1.0))
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module()
Example #5
    def __init__(
            self,
            restrict_scaling_impl: Module,
            scaling_shape: Tuple[int, ...],
            scaling_min_val: Optional[float] = None,
            affine_rescaling: bool = False) -> None:
        super(_StatsScaling, self).__init__()

        if affine_rescaling:
            self.affine_rescaling = _AffineRescaling(scaling_shape)
        else:
            self.affine_rescaling = Identity()
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
Example #6
    def __init__(self,
                 scaling_init: Union[float, Tensor],
                 restrict_scaling_impl: Module,
                 scaling_min_val: Optional[float] = None) -> None:
        super(ConstScaling, self).__init__()
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        if isinstance(scaling_init, Tensor):
            scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
            self.value = StatelessBuffer(scaling_init.detach())
        else:
            scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
            self.value = StatelessBuffer(torch.tensor(scaling_init))
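A sketch of the float branch (hedged: FloatRestrictValue as a pass-through restrict implementation is an assumption about the surrounding library):

    from brevitas.core.restrict_val import FloatRestrictValue  # assumed import path

    # A float init goes through restrict_init_float and is stored as a
    # constant, non-learnable StatelessBuffer.
    const = ConstScaling(scaling_init=0.5, restrict_scaling_impl=FloatRestrictValue())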
Example #7
    def __init__(self,
                 scaling_init: Union[float, Tensor],
                 scaling_shape: Tuple[int, ...],
                 restrict_scaling_impl: Module,
                 scaling_min_val: Optional[float] = None) -> None:
        super(ParameterScaling, self).__init__()

        if isinstance(scaling_init, Tensor):
            scaling_init = scaling_init.detach()
        else:
            # Convert a float init to a tensor before restrict_init_tensor
            # and .dim() are applied below.
            scaling_init = torch.tensor(scaling_init)
        self.restrict_clamp_scaling = _RestrictClampValue(
            scaling_min_val, restrict_scaling_impl)
        scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
        if scaling_init.dim() == 0:
            self.value = Parameter(torch.full(scaling_shape, scaling_init))
        else:
            assert scaling_init.shape == scaling_shape
            self.value = Parameter(scaling_init)
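Unlike Example #2, this variant makes restrict_scaling_impl mandatory and requires a non-scalar init to already match scaling_shape exactly (the assert above). A sketch under the same FloatRestrictValue assumption:

    import torch
    from brevitas.core.restrict_val import FloatRestrictValue  # assumed import path

    init = torch.full((8, 1, 1, 1), 0.25)  # shape must equal scaling_shape
    scaling = ParameterScaling(init, (8, 1, 1, 1), FloatRestrictValue())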