def prepare_weight_quant(cls, module: QuantWBIOL):
    """Build the weight-quantization callable and its call arguments.

    Validates the weight bit width via ``cls.validate_8b_bit_width``,
    then gathers scale, zero point and signedness from *module* and
    returns ``(impl, args, kwargs)`` ready to invoke.
    """
    cls.validate_8b_bit_width(module.quant_weight_bit_width())
    w_scale = module.quant_weight_scale()
    w_zero_point = cls.quant_weight_zero_point(module)
    w_signed = module.is_quant_weight_signed
    # Detach so the exported tensor carries no autograd history.
    w_tensor = module.weight.detach()
    impl, impl_kwargs = cls.gen_quant_impl_kwargs(w_scale, w_zero_point, w_signed)
    return impl, (w_tensor,), impl_kwargs
def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
    """Assert that *module* meets the quantization constraints for export.

    Weight and output quantization must be enabled; weight width is
    validated with ``le_then=True`` while input/output widths are
    validated as 8b. When a bias is present and *requires_quant_bias*
    is true, the bias quantizer must be enabled, signed, and validated
    against 32b with ``le_then=True``.
    """
    assert module.is_weight_quant_enabled
    assert module.is_output_quant_enabled
    cls.validate_8b_bit_width(module.quant_weight_bit_width(), le_then=True)
    cls.validate_8b_bit_width(module.quant_input_bit_width())
    cls.validate_8b_bit_width(module.quant_output_bit_width())
    if requires_quant_bias and module.bias is not None:
        assert module.is_bias_quant_enabled
        assert module.is_quant_bias_signed
        cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True)
def quant_weight_type(module: QuantWBIOL):
    """Return the datatype string for the module's quantized weights.

    A 1-bit weight maps to ``"BIPOLAR"``; any other width maps to
    ``"INT<width>"``.
    """
    width = int(module.quant_weight_bit_width().item())
    return "BIPOLAR" if width == 1 else f"INT{width}"
def quant_weight_bit_width(module: QuantWBIOL):
    """Return the module's weight bit width, passed through 8b validation."""
    return DPUQuantLayerHandler.validate_8b_bit_width(
        module.quant_weight_bit_width())
def test_default_wbiol_weight_bit_width_enabled(
        default_wbiol_layer: QuantWBIOL):
    """The default WBIOL layer should report an 8-bit weight quantizer."""
    expected = torch.tensor(8.)
    assert default_wbiol_layer.quant_weight_bit_width() == expected
def quant_weight_bit_width(cls, module: QuantWBIOL):
    """Return the weight bit width, validated with ``le_then=True``."""
    return cls.validate_8b_bit_width(
        module.quant_weight_bit_width(), le_then=True)
def quant_weight_type(module: QuantWBIOL):
    """Return the FINN datatype for the quantized weights.

    Derived from the weight bit width together with its signedness via
    ``finn_datatype``.
    """
    width = module.quant_weight_bit_width()
    signed = module.is_quant_weight_signed
    return finn_datatype(width, signed)