    def __init__(self,
                 quant_type: QuantType,
                 bit_width: Optional[int],
                 narrow_range: bool) -> None:
        super(BiasQuantProxy, self).__init__()

        if quant_type == QuantType.FP:
            self.tensor_quant = None
        elif quant_type == QuantType.INT:
            tensor_clamp_impl = TensorClamp()
            float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                              float_to_int_impl_type=FloatToIntImplType.ROUND,
                                              min_val=None)
            if bit_width is not None:
                bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type=RestrictValueType.INT)
                self.tensor_quant = PrescaledRestrictIntQuant(narrow_range=narrow_range,
                                                              signed=True,
                                                              tensor_clamp_impl=tensor_clamp_impl,
                                                              msb_clamp_bit_width_impl=bit_width_impl,
                                                              float_to_int_impl=float_to_int_impl)
                self.requires_input_bit_width = False
            else:
                self.tensor_quant = PrescaledIntQuant(narrow_range=narrow_range,
                                                      signed=True,
                                                      tensor_clamp_impl=tensor_clamp_impl,
                                                      float_to_int_impl=float_to_int_impl)
                self.requires_input_bit_width = True
        else:
            raise Exception('Quantization type {} not supported for bias quant.'
                            .format(str(quant_type)))
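Note: the two INT branches differ only in where the bit width comes from: with an explicit bit_width the proxy quantizes independently (requires_input_bit_width = False); without one it expects the accumulator bit width as a runtime input. A plain-PyTorch sketch of the prescaled signed quantization both variants perform, with min_int/max_int reimplemented locally for illustration:

import torch

def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
    if not signed:
        return 0
    return -(2 ** (bit_width - 1)) + 1 if narrow_range else -(2 ** (bit_width - 1))

def max_int(signed: bool, bit_width: int) -> int:
    return 2 ** (bit_width - 1) - 1 if signed else 2 ** bit_width - 1

def prescaled_int_quant(x, scale, bit_width, narrow_range=False):
    int_x = torch.round(x / scale)                        # FloatToIntImplType.ROUND
    int_x = torch.clamp(int_x,                            # TensorClamp
                        min_int(True, narrow_range, bit_width),
                        max_int(True, bit_width))
    return int_x * scale                                  # back to the float domain

bias = torch.randn(8)
print(prescaled_int_quant(bias, torch.tensor(0.01), bit_width=16))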
Example #2
def test_RescalingIntQuant(x, narrow_range, signed, scale, int_scale,
                           bit_width):
    value = torch.tensor(x)
    scale = torch.tensor(scale)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    int_scale = torch.tensor(int_scale, dtype=torch.float)
    tensor_clamp_impl = TensorClamp()
    msb_clamp_bitwidth_mock = Mock()
    msb_clamp_bitwidth_mock.return_value = bit_width
    float_to_int_impl_mock = Mock()
    float_to_int_impl_mock.side_effect = (lambda y: y)
    int_scaling_impl_mock = Mock()
    int_scaling_impl_mock.return_value = int_scale
    scaling_impl = Mock()
    scaling_impl.return_value = scale

    int_quant = IntQuant(signed=signed,
                         narrow_range=narrow_range,
                         tensor_clamp_impl=tensor_clamp_impl,
                         float_to_int_impl=float_to_int_impl_mock)
    obj = RescalingIntQuant(int_scaling_impl=int_scaling_impl_mock,
                            bit_width_impl=msb_clamp_bitwidth_mock,
                            scaling_impl=scaling_impl,
                            int_quant=int_quant)

    output, scale_out, bit_width = obj(value)

    expected_output = int_quant(scale, int_scale, bit_width, value)
    expected_scale = scale / int_scale
    assert torch.allclose(expected_output, output, RTOL, ATOL)
    assert torch.allclose(expected_scale, scale_out, RTOL, ATOL)
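A small numeric illustration of the identity the test asserts (expected_scale = scale / int_scale): dequantizing with scale / int_scale is the same as dividing by int_scale and multiplying by scale. The values below are made up:

import torch

x = torch.tensor([0.3, -1.2, 0.75])
scale, int_scale = torch.tensor(2.0), torch.tensor(127.0)  # e.g. 8-bit signed, narrow range
q = torch.clamp(torch.round(x / scale * int_scale), -127, 127)
dequant_a = (q / int_scale) * scale
dequant_b = q * (scale / int_scale)
assert torch.allclose(dequant_a, dequant_b)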
Example #3
def test_IntQuant(x, narrow_range, signed, bit_width, scale, int_scale,
                  float_to_int_impl, scale_multiplier):
    float_to_int_impl_mock = Mock()
    tensor_clamp_impl = TensorClamp()

    value = torch.tensor(x)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    scale = torch.tensor(scale)
    int_scale = torch.tensor(int_scale)

    tol = scale * scale_multiplier
    float_to_int_impl_mock.side_effect = float_to_int_impl()

    obj = IntQuant(narrow_range=narrow_range,
                   signed=signed,
                   float_to_int_impl=float_to_int_impl_mock,
                   tensor_clamp_impl=tensor_clamp_impl)
    output = obj(scale, int_scale, bit_width, value)

    min_value = int(min_int(signed, narrow_range, bit_width))
    max_value = int(max_int(signed, bit_width))
    admissible_values = list(range(min_value, max_value + 1))

    scaled_value = (value / scale) * int_scale
    expected_output = tensor_clamp(scaled_value,
                                   min_val=min_int(signed, narrow_range,
                                                   bit_width),
                                   max_val=max_int(signed, bit_width))
    expected_output = (expected_output / int_scale) * scale

    int_output = obj.to_int(scale, int_scale, bit_width, value)

    # The assert is performed internally by check_admissible_values
    check_admissible_values(int_output, admissible_values)
    assert torch.allclose(expected_output, output, RTOL, tol)
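The admissible-values property the test relies on reduces to a membership check over the signed integer range; check_admissible_values is assumed to do something equivalent to this minimal helper:

import torch

def check_admissible(int_output, admissible):
    assert all(v in admissible for v in int_output.flatten().tolist())

# 3-bit signed, wide range: integers in [-4, 3]
check_admissible(torch.tensor([-4.0, 0.0, 3.0]), list(range(-4, 4)))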
Example #4
    def __init__(self,
                 signed: bool,
                 narrow_range: bool,
                 quant_type: QuantType,
                 ms_bit_width_to_clamp: int,
                 clamp_at_least_init_val: bool,
                 min_overall_bit_width: Optional[int],
                 max_overall_bit_width: Optional[int],
                 msb_clamp_bit_width_impl_type: BitWidthImplType,
                 override_pretrained_bit_width: bool):
        super(ClampQuantProxy, self).__init__()

        if quant_type == QuantType.FP:
            self.tensor_quant = IdentityPrescaledIntQuant()

        elif quant_type == QuantType.INT:
            msb_clamp_bit_width_impl = MsbClampParameterBitWidth(ms_bit_width_to_clamp=ms_bit_width_to_clamp,
                                                                 clamp_at_least_init_val=clamp_at_least_init_val,
                                                                 min_overall_bit_width=min_overall_bit_width,
                                                                 max_overall_bit_width=max_overall_bit_width,
                                                                 bit_width_impl_type=msb_clamp_bit_width_impl_type,
                                                                 override_pretrained=override_pretrained_bit_width)
            tensor_clamp_impl = TensorClamp()
            float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                              float_to_int_impl_type=FloatToIntImplType.ROUND,
                                              min_val=None)
            tensor_quant_impl = PrescaledRestrictIntQuantWithInputBitWidth
            self.tensor_quant = tensor_quant_impl(signed=signed,
                                                  narrow_range=narrow_range,
                                                  tensor_clamp_impl=tensor_clamp_impl,
                                                  float_to_int_impl=float_to_int_impl,
                                                  msb_clamp_bit_width_impl=msb_clamp_bit_width_impl)
        else:
            raise Exception("Quantization type {} not supported for accumulators.".format(quant_type))
Example #5
    def __init__(self,
                 signed: bool,
                 quant_type: QuantType,
                 ls_bit_width_to_trunc: int,
                 trunc_at_least_init_val: bool,
                 min_overall_bit_width: Optional[int],
                 max_overall_bit_width: Optional[int],
                 lsb_trunc_bit_width_impl_type: BitWidthImplType,
                 explicit_rescaling: bool,
                 override_pretrained_bit_width: bool):
        super(TruncQuantProxy, self).__init__()
        self.register_buffer(ZERO_HW_SENTINEL_NAME, torch.tensor(ZERO_HW_SENTINEL_VALUE))
        self.explicit_rescaling = explicit_rescaling

        if quant_type == QuantType.FP:
            self.lsb_trunc_bit_width_impl = ZeroLsbTruncBitWidth()
            self.tensor_quant = IdentityPrescaledIntQuant()

        elif quant_type == QuantType.INT:
            self.lsb_trunc_bit_width_impl = LsbTruncParameterBitWidth(ls_bit_width_to_trunc=ls_bit_width_to_trunc,
                                                                      trunc_at_least_init_val=trunc_at_least_init_val,
                                                                      min_overall_bit_width=min_overall_bit_width,
                                                                      max_overall_bit_width=max_overall_bit_width,
                                                                      bit_width_impl_type=lsb_trunc_bit_width_impl_type,
                                                                      override_pretrained=override_pretrained_bit_width)
            tensor_clamp_impl = TensorClamp()
            float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                              float_to_int_impl_type=FloatToIntImplType.FLOOR,
                                              min_val=None)
            self.tensor_quant = PrescaledIntQuant(signed=signed,
                                                  narrow_range=False,
                                                  tensor_clamp_impl=tensor_clamp_impl,
                                                  float_to_int_impl=float_to_int_impl)
        else:
            raise Exception("Quantization type {} not supported for accumulators.".format(quant_type))
Example #6
    def __init__(
            self,
            scaling_impl: Module,
            tensor_clamp_impl: Module = TensorClamp(),
            quant_delay_steps: int = 0):
        super(ClampedBinaryQuant, self).__init__()
        self.scaling_impl = scaling_impl
        self.bit_width = BitWidthConst(1)
        self.zero_point = StatelessBuffer(torch.tensor(0.0))
        self.delay_wrapper = DelayWrapper(quant_delay_steps)
        self.tensor_clamp_impl = tensor_clamp_impl
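A plain-PyTorch sketch of what the module above computes (ignoring the delay wrapper): inputs are clamped to [-scale, scale], so out-of-range values receive no gradient, then binarized to {-scale, +scale}. Mapping zero to +scale is an assumption here; Brevitas uses a dedicated sign op.

import torch

def clamped_binary_quant(x, scale):
    y = torch.clamp(x, -scale, scale)          # tensor_clamp_impl
    return torch.where(y >= 0, scale, -scale)  # 1-bit output, zero_point = 0

print(clamped_binary_quant(torch.tensor([-2.0, 0.3, 5.0]), torch.tensor(1.0)))
# -> tensor([-1.,  1.,  1.])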
Example #7
    def __init__(self,
                 narrow_range: bool,
                 signed: bool,
                 float_to_int_impl: Module = RoundSte(),
                 tensor_clamp_impl: Module = TensorClamp(),
                 quant_delay_steps: int = 0):
        super(IntQuant, self).__init__()
        self.float_to_int_impl = float_to_int_impl
        self.tensor_clamp_impl = tensor_clamp_impl
        self.signed = signed
        self.narrow_range = narrow_range
        self.delay_wrapper = DelayWrapper(quant_delay_steps)
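A usage sketch for the constructor above, assuming the call convention the tests in this file use, i.e. forward takes (scale, int_scale, bit_width, x); argument meanings may differ across Brevitas versions.

import torch

int_quant = IntQuant(narrow_range=True, signed=True)  # defaults: RoundSte(), TensorClamp()
y = int_quant(torch.tensor(0.1), torch.tensor(127.0), torch.tensor(8.0), torch.randn(4))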
Example #8
def test_PrescaledRestrictIntQuant(x, narrow_range, signed, scale, bit_width):
    value = torch.tensor(x)
    scale = torch.tensor(scale)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    tensor_clamp_impl = TensorClamp()

    msb_clamp_bitwidth_mock = Mock()
    msb_clamp_bitwidth_mock.return_value = bit_width
    float_to_int_impl_mock = Mock()
    float_to_int_impl_mock.side_effect = (lambda y: y)
    int_quant = IntQuant(signed=signed,
                         narrow_range=narrow_range,
                         tensor_clamp_impl=tensor_clamp_impl,
                         float_to_int_impl=float_to_int_impl_mock)
    obj = PrescaledRestrictIntQuant(bit_width_impl=msb_clamp_bitwidth_mock,
                                    int_quant=int_quant)

    output, scale, bit_width = obj(value, scale)

    expected_output = int_quant(scale, torch.tensor(1.0), bit_width, value)

    assert torch.allclose(expected_output, output, RTOL, ATOL)
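The identity side_effect pattern used throughout these tests turns a Mock into a pass-through float-to-int op, isolating clamping behavior from rounding. A minimal standalone illustration with unittest.mock:

from unittest.mock import Mock

float_to_int = Mock()
float_to_int.side_effect = lambda y: y   # behaves as identity but still records calls
assert float_to_int(3.7) == 3.7
assert float_to_int.call_count == 1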
Example #9
def test_PrescaledRestrictIntQuantWithInputBitWidth(x, narrow_range, signed, scale, bit_width):
    value = torch.tensor(x)
    scale = torch.tensor(scale)
    tensor_clamp_impl = TensorClamp()

    msb_clamp_bitwidth_mock = Mock()
    msb_clamp_bitwidth_mock.return_value = torch.tensor(bit_width, dtype=torch.float)
    float_to_int_impl_mock = Mock()
    float_to_int_impl_mock.side_effect = (lambda y: y)

    obj = PrescaledRestrictIntQuantWithInputBitWidth(narrow_range=narrow_range, signed=signed,
                                                     tensor_clamp_impl=tensor_clamp_impl,
                                                     msb_clamp_bit_width_impl=msb_clamp_bitwidth_mock,
                                                     float_to_int_impl=float_to_int_impl_mock)

    output, scale, bit_width = obj(value, scale, bit_width, torch.tensor(ZERO_HW_SENTINEL_VALUE))

    expected_IntQuant = IntQuant(signed=signed, narrow_range=narrow_range, tensor_clamp_impl=tensor_clamp_impl,
                                 float_to_int_impl=float_to_int_impl_mock)
    expected_output = expected_IntQuant(scale, torch.tensor(ZERO_HW_SENTINEL_VALUE) + 1, bit_width, value)

    assert torch.allclose(expected_output, output, RTOL, ATOL)
Example #10
def _weight_quant_init_impl(bit_width: Optional[int],
                            quant_type: QuantType,
                            narrow_range: bool,
                            scaling_override: Optional[nn.Module],
                            restrict_scaling_type: RestrictValueType,
                            scaling_const: float,
                            scaling_stats_op: StatsOp,
                            scaling_impl_type: ScalingImplType,
                            scaling_stats_reduce_dim: Optional[int],
                            scaling_shape: Tuple[int, ...],
                            scaling_min_val: Optional[float],
                            bit_width_impl_type: Optional[BitWidthImplType],
                            restrict_bit_width_type: Optional[RestrictValueType],
                            min_overall_bit_width: Optional[int],
                            max_overall_bit_width: Optional[int],
                            bit_width_impl_override: Optional[Union[BitWidthConst, BitWidthParameter]],
                            scaling_stats_input_view_shape_impl: StatsInputViewShapeImpl,
                            scaling_stats_input_concat_dim: int,
                            ternary_threshold: Optional[float],
                            scaling_stats_sigma: Optional[float],
                            tracked_parameter_list: List[torch.nn.Parameter],
                            zero_hw_sentinel: torch.Tensor,
                            override_pretrained_bit_width: bool):

    if quant_type == QuantType.FP:
        tensor_quant = IdentityQuant()
    else:
        if scaling_impl_type != ScalingImplType.OVERRIDE and scaling_override is not None:
            raise Exception("Overriding scaling requires to set ScalingImplType to OVERRIDE explicitly.")
        if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is None:
            raise Exception("Overriding scaling requires to pass a scaling impl module.")

        if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is not None:
            scaling_impl = scaling_override

        elif scaling_impl_type == ScalingImplType.STATS \
                or scaling_impl_type == ScalingImplType.AFFINE_STATS \
                or scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS:
            stats_scaling = ParameterStatsScaling(stats_op=scaling_stats_op,
                                                  restrict_scaling_type=restrict_scaling_type,
                                                  tracked_parameter_list=tracked_parameter_list,
                                                  stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
                                                  stats_input_concat_dim=scaling_stats_input_concat_dim,
                                                  sigma=scaling_stats_sigma,
                                                  scaling_min_val=scaling_min_val,
                                                  stats_reduce_dim=scaling_stats_reduce_dim,
                                                  stats_output_shape=scaling_shape,
                                                  affine=scaling_impl_type == ScalingImplType.AFFINE_STATS)
            if scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS:
                if quant_type == QuantType.BINARY or quant_type == QuantType.TERNARY:
                    raise Exception("Parameter from stats scaling is currently not supported for binary/ternary")
                scaling_init = stats_scaling(zero_hw_sentinel).detach()
                scaling_impl = StandaloneScaling(scaling_init=scaling_init,
                                                 parameter_shape=scaling_shape,
                                                 restrict_scaling_type=restrict_scaling_type,
                                                 is_parameter=True,
                                                 scaling_min_val=scaling_min_val)
            else:
                scaling_impl = stats_scaling

        elif scaling_impl_type == ScalingImplType.CONST or scaling_impl_type == ScalingImplType.HE:
            if scaling_impl_type == ScalingImplType.HE:
                scaling_const = 0.0
                for param in tracked_parameter_list:  # takes average of He scaling over parameter list
                    two_dim_param = param.view(param.shape[0], -1)
                    scaling_const += math.sqrt(2.0 / two_dim_param.shape[1])
                scaling_const /= len(tracked_parameter_list)
            scaling_init = torch.tensor(scaling_const)
            scaling_impl = StandaloneScaling(scaling_init=scaling_init,
                                             parameter_shape=SCALING_SCALAR_SHAPE,
                                             restrict_scaling_type=restrict_scaling_type,
                                             is_parameter=False,
                                             scaling_min_val=None)
        else:
            raise Exception("Scaling type {} not supported for weight quantization"
                            .format(str(scaling_impl_type)))

        if bit_width == 1 and quant_type == QuantType.BINARY:
            tensor_quant = BinaryQuant(scaling_impl=scaling_impl)

        elif bit_width == 2 and quant_type == QuantType.TERNARY:
            tensor_quant = TernaryQuant(scaling_impl=scaling_impl, threshold=ternary_threshold)

        elif bit_width >= 2 and quant_type == QuantType.INT:
            if bit_width_impl_override is None:
                if (bit_width_impl_type is None
                        or bit_width is None
                        or restrict_bit_width_type is None):
                    raise Exception("Bit width is not defined properly")

                if bit_width_impl_type == BitWidthImplType.CONST:
                    tensor_clamp_impl = TensorClampSte()
                    bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type)
                elif bit_width_impl_type == BitWidthImplType.PARAMETER:
                    tensor_clamp_impl = TensorClamp()
                    bit_width_impl = BitWidthParameter(bit_width_init=bit_width,
                                                       restrict_bit_width_type=restrict_bit_width_type,
                                                       min_overall_bit_width=min_overall_bit_width,
                                                       max_overall_bit_width=max_overall_bit_width,
                                                       override_pretrained=override_pretrained_bit_width)
                else:
                    raise Exception("Bit width type {} not supported for weight quantization."
                                    .format(str(bit_width_impl_type)))
            else:
                tensor_clamp_impl = TensorClamp()
                bit_width_impl = bit_width_impl_override

            float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                              float_to_int_impl_type=FloatToIntImplType.ROUND,
                                              min_val=None)
            int_scaling_impl = IntScaling(narrow_range,
                                          signed=True,
                                          restrict_scaling_type=restrict_scaling_type)
            tensor_quant = RescalingIntQuant(narrow_range=narrow_range,
                                             signed=True,
                                             scaling_impl=scaling_impl,
                                             int_scaling_impl=int_scaling_impl,
                                             tensor_clamp_impl=tensor_clamp_impl,
                                             msb_clamp_bit_width_impl=bit_width_impl,
                                             float_to_int_impl=float_to_int_impl,
                                             runtime=False)
        else:
            raise Exception('Unsupported weight quantization: {} bit width, {} quantization.'
                            .format(bit_width, str(quant_type)))
    return tensor_quant
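A numeric sketch of the HE branch above: the scaling constant is the He-init standard deviation sqrt(2 / fan_in), averaged over the tracked parameters. The shapes below are made up:

import math
import torch

tracked = [torch.empty(64, 3, 3, 3), torch.empty(128, 64, 3, 3)]  # e.g. two conv weights
fan_ins = [p.view(p.shape[0], -1).shape[1] for p in tracked]      # 27 and 576
scaling_const = sum(math.sqrt(2.0 / f) for f in fan_ins) / len(tracked)
print(scaling_const)  # (sqrt(2/27) + sqrt(2/576)) / 2, roughly 0.166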
Example #11
    def __init__(self,
                 activation_impl: Module,
                 bit_width: int,
                 signed: bool,
                 narrow_range: bool,
                 min_val: float,
                 max_val: float,
                 quant_type: QuantType,
                 float_to_int_impl_type: FloatToIntImplType,
                 scaling_override: Optional[Module],
                 scaling_impl_type: ScalingImplType,
                 scaling_per_channel: bool,
                 scaling_min_val: Optional[float],
                 scaling_stats_sigma: Optional[float],
                 scaling_stats_op: Optional[StatsOp],
                 scaling_stats_buffer_momentum: Optional[float],
                 scaling_stats_permute_dims: Optional[Tuple],
                 per_channel_broadcastable_shape: Optional[Tuple[int, ...]],
                 min_overall_bit_width: Optional[int],
                 max_overall_bit_width: Optional[int],
                 bit_width_impl_override: Module,
                 bit_width_impl_type: BitWidthImplType,
                 restrict_bit_width_type: RestrictValueType,
                 restrict_scaling_type: RestrictValueType,
                 override_pretrained_bit_width: bool):
        super(ActivationQuantProxy, self).__init__()

        if not signed and min_val != 0.0:
            raise Exception("Min val has to be 0.0 when quantization is unsigned.")
        if scaling_per_channel and per_channel_broadcastable_shape is None:
            raise Exception("Per channel scaling requires to specify number of channels.")

        if quant_type == QuantType.FP:
            tensor_quant = IdentityQuant()
        else:
            if scaling_impl_type != ScalingImplType.OVERRIDE and scaling_override is not None:
                raise Exception("Overriding scaling requires to set ScalingImplType to OVERRIDE explicitly.")
            if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is None:
                raise Exception("Overriding scaling requires to pass a scaling impl module.")

            if scaling_per_channel:
                scaling_shape = per_channel_broadcastable_shape
            else:
                scaling_shape = SCALING_SCALAR_SHAPE

            if scaling_impl_type == ScalingImplType.OVERRIDE and scaling_override is not None:
                scaling_impl = scaling_override
                runtime = False

            elif scaling_impl_type == ScalingImplType.CONST or scaling_impl_type == ScalingImplType.PARAMETER:
                scaling_init = RescalingIntQuant.scaling_init_from_min_max(min_val, max_val)
                scaling_impl = StandaloneScaling(is_parameter=scaling_impl_type == ScalingImplType.PARAMETER,
                                                 parameter_shape=scaling_shape,
                                                 restrict_scaling_type=restrict_scaling_type,
                                                 scaling_init=scaling_init,
                                                 scaling_min_val=scaling_min_val)
                runtime = False
            elif scaling_impl_type == ScalingImplType.STATS or scaling_impl_type == ScalingImplType.AFFINE_STATS:

                if scaling_per_channel and scaling_stats_op != StatsOp.MAX_AVE:
                    scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                    scaling_stats_reduce_dim = 1
                elif scaling_per_channel and scaling_stats_op == StatsOp.MAX_AVE:
                    raise Exception("Can't do per channel scaling with MAX AVE statistics.")
                elif not scaling_per_channel and scaling_stats_op == StatsOp.MAX_AVE:
                    scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
                    scaling_stats_reduce_dim = 1
                else:  # not scaling_per_channel
                    scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
                    scaling_stats_reduce_dim = None
                    scaling_stats_permute_dims = None

                stats_buffer_init = RescalingIntQuant.scaling_init_from_min_max(min_val, max_val).item()
                scaling_impl = RuntimeStatsScaling(stats_op=scaling_stats_op,
                                                   restrict_scaling_type=restrict_scaling_type,
                                                   stats_input_view_shape_impl=scaling_stats_input_view_shape_impl,
                                                   stats_output_shape=scaling_shape,
                                                   sigma=scaling_stats_sigma,
                                                   scaling_min_val=scaling_min_val,
                                                   stats_reduce_dim=scaling_stats_reduce_dim,
                                                   stats_buffer_momentum=scaling_stats_buffer_momentum,
                                                   stats_buffer_init=stats_buffer_init,
                                                   stats_permute_dims=scaling_stats_permute_dims,
                                                   affine=scaling_impl_type == ScalingImplType.AFFINE_STATS)
                runtime = True
            else:
                raise Exception("Scaling type {} not supported for int runtime quantization"
                                .format(str(scaling_impl_type)))

            if quant_type == QuantType.BINARY:
                if not signed:
                    raise Exception("Binary activation supports only signed activations")
                tensor_quant = ClampedBinaryQuant(scaling_impl=scaling_impl)

            elif quant_type == QuantType.INT:

                if bit_width_impl_override is None:
                    if bit_width_impl_type is None or bit_width is None or restrict_bit_width_type is None:
                        raise Exception("Bit width is not defined properly")

                    if bit_width_impl_type == BitWidthImplType.CONST:
                        tensor_clamp_impl = TensorClamp()  # If it's const, don't pass gradients to clipped values
                        msb_clamp_bit_width_impl = BitWidthConst(bit_width, restrict_bit_width_type)
                    elif bit_width_impl_type == BitWidthImplType.PARAMETER:
                        tensor_clamp_impl = TensorClamp()  # if it's learned, pass gradients to the bit width
                        msb_clamp_bit_width_impl = BitWidthParameter(bit_width,
                                                                     min_overall_bit_width,
                                                                     max_overall_bit_width,
                                                                     restrict_bit_width_type,
                                                                     override_pretrained_bit_width)
                    else:
                        raise Exception("Bit width type {} not supported for weight quantization"
                                        .format(str(bit_width_impl_type)))
                else:
                    msb_clamp_bit_width_impl = bit_width_impl_override
                    tensor_clamp_impl = TensorClamp()  # if there is an override, it's learned

                float_to_int_impl = RestrictValue(restrict_value_type=RestrictValueType.INT,
                                                  float_to_int_impl_type=float_to_int_impl_type,
                                                  min_val=None)
                int_scaling_impl = IntScaling(narrow_range,
                                              signed=signed,
                                              restrict_scaling_type=restrict_scaling_type)
                tensor_quant = RescalingIntQuant(signed=signed,
                                                 narrow_range=narrow_range,
                                                 scaling_impl=scaling_impl,
                                                 int_scaling_impl=int_scaling_impl,
                                                 tensor_clamp_impl=tensor_clamp_impl,
                                                 msb_clamp_bit_width_impl=msb_clamp_bit_width_impl,
                                                 float_to_int_impl=float_to_int_impl,
                                                 runtime=runtime)
            else:
                raise Exception("Quantization type {} not supported for activations.".format(quant_type))

        self.fused_activation_quant_proxy = FusedActivationQuantProxy(activation_impl, tensor_quant)
        self.scaling_impl_type = scaling_impl_type  # needed to switch between different scaling modes
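scaling_init_from_min_max is assumed to derive the initial scale from the largest magnitude the activation range must represent; a minimal reimplementation for illustration:

import torch

def scaling_init_from_min_max(min_val: float, max_val: float) -> torch.Tensor:
    # The scale must cover the largest magnitude in [min_val, max_val].
    return torch.tensor(max(abs(min_val), abs(max_val)))

print(scaling_init_from_min_max(-6.0, 6.0))  # signed range   -> tensor(6.)
print(scaling_init_from_min_max(0.0, 6.0))   # unsigned range (min_val must be 0.0)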