Code example #1
    def __init__(self, signed: bool, quant_type: QuantType,
                 ls_bit_width_to_trunc: int, trunc_at_least_init_val: bool,
                 min_overall_bit_width: Optional[int],
                 max_overall_bit_width: Optional[int],
                 lsb_trunc_bit_width_impl_type: BitWidthImplType,
                 explicit_rescaling: bool,
                 override_pretrained_bit_width: bool):
        """Wire up the LSB-truncation bit-width impl and the tensor quantizer
        for the requested quantization type.

        For ``QuantType.FP`` both collaborators are identity/no-op stand-ins;
        for ``QuantType.INT`` a learned truncation bit-width and a
        floor-based prescaled integer quantizer are built. Any other
        quantization type raises.
        """
        super(TruncQuantProxy, self).__init__()
        self.explicit_rescaling = explicit_rescaling

        if quant_type == QuantType.FP:
            # Floating point: nothing to truncate, pass values through.
            self.lsb_trunc_bit_width_impl = ZeroLsbTruncBitWidth()
            self.tensor_quant = IdentityPrescaledIntQuant()
        elif quant_type == QuantType.INT:
            # FLOOR rounding implements truncation of the least-significant bits.
            to_int_impl = RestrictValue(
                restrict_value_type=RestrictValueType.INT,
                float_to_int_impl_type=FloatToIntImplType.FLOOR,
                min_val=None)
            clamp_impl = TensorClamp()
            self.lsb_trunc_bit_width_impl = LsbTruncParameterBitWidth(
                ls_bit_width_to_trunc=ls_bit_width_to_trunc,
                trunc_at_least_init_val=trunc_at_least_init_val,
                min_overall_bit_width=min_overall_bit_width,
                max_overall_bit_width=max_overall_bit_width,
                bit_width_impl_type=lsb_trunc_bit_width_impl_type,
                override_pretrained=override_pretrained_bit_width)
            self.tensor_quant = PrescaledRestrictIntQuantWithInputBitWidth(
                narrow_range=False,
                signed=signed,
                tensor_clamp_impl=clamp_impl,
                msb_clamp_bit_width_impl=IdentityBitWidth(),
                float_to_int_impl=to_int_impl)
        else:
            raise Exception(
                "Quantization type {} not supported for accumulators.".format(
                    quant_type))
Code example #2 — file: test_quant.py, project: solitary-1/brevitas
def test_PrescaledRestrictIntQuantWithInputBitWidth(x, narrow_range, signed, scale, bit_width):
    """The proxy's output must equal IntQuant invoked directly with the
    scale and bit width the proxy itself reports back."""
    input_tensor = torch.tensor(x)
    scale_tensor = torch.tensor(scale)
    clamp_impl = TensorClamp()

    # Mock out the bit-width provider and use an identity float-to-int impl
    # so the comparison below is exact.
    bit_width_mock = Mock(return_value=torch.tensor(bit_width, dtype=torch.float))
    identity_to_int = Mock(side_effect=lambda y: y)

    quant = PrescaledRestrictIntQuantWithInputBitWidth(
        narrow_range=narrow_range, signed=signed,
        tensor_clamp_impl=clamp_impl,
        msb_clamp_bit_width_impl=bit_width_mock,
        float_to_int_impl=identity_to_int)

    output, out_scale, out_bit_width = quant(
        input_tensor, scale_tensor, bit_width, torch.tensor(ZERO_HW_SENTINEL_VALUE))

    # Recompute the expected value with the scale/bit-width the proxy returned.
    reference_quant = IntQuant(signed=signed, narrow_range=narrow_range,
                               tensor_clamp_impl=clamp_impl,
                               float_to_int_impl=identity_to_int)
    expected_output = reference_quant(
        out_scale, torch.tensor(ZERO_HW_SENTINEL_VALUE) + 1, out_bit_width, input_tensor)

    assert torch.allclose(expected_output, output, RTOL, ATOL)
Code example #3 — file: parameter_quant.py, project: xrick/brevitas
    def __init__(self, quant_type: QuantType, bit_width: Optional[int],
                 narrow_range: bool) -> None:
        """Build the bias quantizer for the requested quantization type.

        ``QuantType.FP`` disables quantization entirely. ``QuantType.INT``
        builds a signed, round-to-nearest prescaled quantizer: with a fixed
        ``bit_width`` the width is a constant, otherwise it must be supplied
        with the input at call time (``requires_input_bit_width``). Any other
        quantization type raises.
        """
        super(BiasQuantProxy, self).__init__()
        self.scale_output_shape = OVER_BATCH_OVER_CHANNELS_SHAPE

        if quant_type == QuantType.FP:
            # Floating point: no quantizer at all.
            self.tensor_quant = None
        elif quant_type == QuantType.INT:
            clamp_impl = TensorClamp()
            round_to_int = RestrictValue(
                restrict_value_type=RestrictValueType.INT,
                float_to_int_impl_type=FloatToIntImplType.ROUND,
                min_val=None)
            if bit_width is None:
                # Width is unknown here; the caller provides it per input.
                self.tensor_quant = PrescaledRestrictIntQuantWithInputBitWidth(
                    narrow_range=narrow_range,
                    signed=True,
                    tensor_clamp_impl=clamp_impl,
                    msb_clamp_bit_width_impl=IdentityBitWidth(),
                    float_to_int_impl=round_to_int)
                self.requires_input_bit_width = True
            else:
                # Fixed width: bake it in as a constant.
                self.tensor_quant = PrescaledRestrictIntQuant(
                    narrow_range=narrow_range,
                    signed=True,
                    tensor_clamp_impl=clamp_impl,
                    msb_clamp_bit_width_impl=BitWidthConst(
                        bit_width, restrict_bit_width_type=RestrictValueType.INT),
                    float_to_int_impl=round_to_int)
                self.requires_input_bit_width = False
        else:
            raise Exception(
                'Quantization type {} not supported for bias quant.'.format(
                    str(quant_type)))