def prepare_weight_quant(cls, module: QuantWBIOL):
    cls.validate_8b_bit_width(module.quant_weight_bit_width())
    scale = module.quant_weight_scale()
    zero_point = cls.quant_weight_zero_point(module)
    signed = module.is_quant_weight_signed
    weight = module.weight.detach()
    quant_impl, quant_kwargs = cls.gen_quant_impl_kwargs(scale, zero_point, signed)
    return quant_impl, (weight,), quant_kwargs

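# Hedged sketch of the affine fake-quantization the returned quant_impl is
# presumed to perform. The scale, zero point, and signedness values are made
# up for illustration; the real implementation comes from gen_quant_impl_kwargs
# and may differ.
import torch

weight = torch.randn(4, 4)
scale, zero_point, signed = torch.tensor(0.05), torch.tensor(0.), True
qmin, qmax = (-128, 127) if signed else (0, 255)

# Quantize to the integer grid, then dequantize back to float.
int_weight = torch.clamp(torch.round(weight / scale + zero_point), qmin, qmax)
fake_quant_weight = (int_weight - zero_point) * scale
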
def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
    assert module.is_weight_quant_enabled
    assert module.is_output_quant_enabled
    cls.validate_8b_bit_width(module.quant_weight_bit_width(), le_then=True)
    cls.validate_8b_bit_width(module.quant_input_bit_width())
    cls.validate_8b_bit_width(module.quant_output_bit_width())
    if module.bias is not None and requires_quant_bias:
        assert module.is_bias_quant_enabled
        assert module.is_quant_bias_signed
        cls.validate_32b_bit_width(module.quant_bias_bit_width(), le_then=True)

def maybe_quant_bias(
        module: QuantWBIOL,
        quant_weight_scale: Tensor,
        quant_output_scale: Optional[Union[Tensor, float]],
        quant_bit_width: Optional[Tensor]):
    if module.bias is None:
        return None
    elif module.bias is not None and not module.is_bias_quant_enabled:
        bias = torch.t(module.bias.type(torch.FloatTensor)).detach()
        # account for both scalar tensors and array tensors with 1 element
        if len(quant_weight_scale.shape) > 0 and len(quant_weight_scale) > 1:
            quant_weight_scale = quant_weight_scale.reshape(bias.shape)
        # divide by the weight scale, since the add happens before the mul
        bias /= quant_weight_scale
        return bias
    else:  # bias quant enabled
        assert quant_output_scale, 'Quant bias export requires caching of the output scale'
        if not isinstance(quant_output_scale, Tensor):
            # item() might have been called on the cached scale
            quant_output_scale = torch.tensor(quant_output_scale)
        quant_bias = module.bias_quant(module.bias, quant_output_scale, quant_bit_width)
        quant_bias = torch.t(quant_bias.value.type(torch.FloatTensor)).detach()
        quant_bias /= quant_output_scale
        quant_bias = torch.round(quant_bias)
        return quant_bias

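# A small numeric sketch of the quantized-bias branch above: the float bias is
# divided by the cached output scale and rounded, yielding the integer value
# the accumulator expects. All values below are made up for illustration.
import torch

bias = torch.tensor([0.25, -0.50])
quant_output_scale = torch.tensor(0.125)

int_bias = torch.round(bias / quant_output_scale)
print(int_bias)  # tensor([ 2., -4.])
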
def quant_bias_bit_width(cls, module: QuantWBIOL):
    if module.bias is not None:
        bit_width = module.quant_bias_bit_width()
        return DPUQuantLayerHandler.validate_8b_bit_width(bit_width, le_then=True)
    else:
        return None

def prepare_bias(cls, module: QuantWBIOL):
    if module.bias is not None and not module.is_bias_quant_enabled:
        bias = module.bias.detach()
    elif module.bias is not None and module.is_bias_quant_enabled:
        bias = module.quant_bias()
    else:
        bias = module.bias
    return bias

def prepare_bias(cls, module: QuantWBIOL):
    if module.bias is not None and not module.is_bias_quant_enabled:
        bias = module.bias.detach()
    elif module.bias is not None and module.is_bias_quant_enabled:
        # export the dequantized value
        bias = module.quant_bias().value
    else:
        bias = module.bias
    return bias

def maybe_int_bias(module: QuantWBIOL):
    if module.bias is not None:
        if module.is_bias_quant_enabled:
            bias = module.int_bias(float_datatype=True)
        else:
            bias = module.bias
        bias = torch.t(bias).detach()
    else:
        bias = None
    return bias

def maybe_int_bias(module: QuantWBIOL):
    if module.bias is not None:
        if module.is_bias_quant_enabled:
            bias = module.int_bias(float_datatype=True)
        else:
            bias = module.bias
        bias_shape = [1] * len(module.weight.shape)
        bias_shape[1] = -1
        # shape should broadcast with activations along the channel dim
        bias = bias.view(bias_shape).detach()
    else:
        bias = None
    return bias

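# To make the broadcast shape concrete (shapes assumed for illustration):
# for a Conv2d weight of shape (out_ch, in_ch, kH, kW), bias_shape becomes
# [1, -1, 1, 1], so a per-output-channel bias broadcasts along the channel
# dimension of an NCHW activation.
import torch

weight = torch.empty(16, 3, 3, 3)   # Conv2d weight: (out_ch, in_ch, kH, kW)
bias = torch.empty(16)              # one bias value per output channel

bias_shape = [1] * len(weight.shape)
bias_shape[1] = -1                  # broadcast along the channel dim of NCHW
print(bias.view(bias_shape).shape)  # torch.Size([1, 16, 1, 1])
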
def quant_bias_scale(module: QuantWBIOL):
    if module.bias is not None:
        scale = module.quant_bias_scale()
        return DPUQuantLayerHandler.neg_scalar_exponent_from_scale(scale)
    else:
        return None

def int_weight(module: QuantWBIOL):
    return module.int_weight(float_datatype=False).detach()

def maybe_quant_bias_scale(module: QuantWBIOL):
    if module.is_bias_quant_enabled:
        return module.quant_bias_scale()
    else:
        return None

def maybe_quant_bias_type(module: QuantWBIOL):
    if module.is_bias_quant_enabled:
        return finn_datatype(module.quant_bias_bit_width(), module.is_quant_bias_signed)
    else:
        return None

def int_weight(module: QuantWBIOL):
    int_weight = module.int_weight(float_datatype=False).detach()
    return int_weight.type(torch.int8)

def quant_weight_type(module: QuantWBIOL):
    return finn_datatype(module.quant_weight_bit_width(), module.is_quant_weight_signed)

def int_weight(module: QuantWBIOL):
    return torch.t(module.int_weight(float_datatype=True)).detach()

def quant_weight_bit_width(cls, module: QuantWBIOL):
    bit_width = module.quant_weight_bit_width()
    return cls.validate_8b_bit_width(bit_width, le_then=True)

def int_bias(module: QuantWBIOL):
    if module.bias is not None:
        return module.int_bias(float_datatype=False).detach()
    else:
        return None

def quant_weight_bit_width(module: QuantWBIOL):
    bit_width = module.quant_weight_bit_width()
    return DPUQuantLayerHandler.validate_8b_bit_width(bit_width)

def quant_weight_scale(cls, module: QuantWBIOL):
    quant_weight_scale = module.quant_weight_scale()
    return cls.validate_neg_scalar_int_exponent(quant_weight_scale)

def quant_weight_scale(module: QuantWBIOL):
    quant_weight_scale = module.quant_weight_scale()
    return DPUQuantLayerHandler.neg_scalar_exponent_from_scale(quant_weight_scale)

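# The DPU flow expects power-of-two scales, so neg_scalar_exponent_from_scale
# presumably recovers the negated exponent: a scale of 2**(-n) maps to n.
# A sketch of that arithmetic, assuming this behavior (the helper's actual
# implementation may differ):
import torch

scale = torch.tensor(0.125)         # 2 ** -3
neg_exponent = -torch.log2(scale)
print(int(neg_exponent.item()))     # 3
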
def quant_bias_scale(cls, module: QuantWBIOL):
    if module.bias is not None:
        scale = module.quant_bias_scale()
        return cls.validate_neg_scalar_int_exponent(scale)
    else:
        return None

def int_weight(module: QuantWBIOL):
    int_weight = module.int_weight(float_datatype=False).detach()
    if module.is_quant_weight_signed:
        return int_weight.type(torch.int8)
    else:
        return int_weight.type(torch.uint8)

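# A quick check of the dtype selection (values made up): a signed quantizer
# lands its integer weights in torch.int8, an unsigned one in torch.uint8,
# so downstream consumers see a storage type matching the signedness.
import torch

int_weight = torch.tensor([[-3., 5.], [7., -1.]])  # made-up integer-valued weights
signed = True  # mirrors module.is_quant_weight_signed
dtype = torch.int8 if signed else torch.uint8
print(int_weight.type(dtype).dtype)  # torch.int8
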
def test_default_wbiol_quant_weight_zero_point(default_wbiol_layer: QuantWBIOL):
    assert default_wbiol_layer.quant_weight_zero_point() == torch.tensor(0.)

def quant_weight_type(module: QuantWBIOL):
    bit_width = int(module.quant_weight_bit_width().item())
    if bit_width == 1:
        return "BIPOLAR"
    else:
        return f"INT{bit_width}"

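# The mapping is easiest to see with a couple of bit widths (inputs made up):
# a 1-bit weight is reported as BIPOLAR, anything wider as INT<n>. This
# standalone helper repeats the mapping above minus the module plumbing.
def datatype_name(bit_width: int) -> str:
    return "BIPOLAR" if bit_width == 1 else f"INT{bit_width}"

print(datatype_name(1))  # BIPOLAR
print(datatype_name(4))  # INT4
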
def test_default_wbiol_quant_bias_zero_point(default_wbiol_layer: QuantWBIOL):
    assert default_wbiol_layer.quant_bias_zero_point() is None

def quant_weight_scale(module: QuantWBIOL):
    return torch.t(module.quant_weight_scale().type(torch.FloatTensor)).detach()

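# Sketch of what the transpose does to a per-output-channel scale (shape
# assumed for illustration): torch.t flips (out_channels, 1) to
# (1, out_channels), presumably to match the transposed weight layout used
# by the export target.
import torch

scale = torch.tensor([[0.5], [0.25], [0.125]])  # made-up per-channel scales
print(torch.t(scale).shape)  # torch.Size([1, 3])
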
def test_default_wbiol_quant_bias_scale(default_wbiol_layer: QuantWBIOL):
    assert default_wbiol_layer.quant_bias_scale() is None

def int_bias(module: QuantWBIOL):
    if module.bias is not None:
        int_bias = module.int_bias(float_datatype=False).detach()
        return int_bias.type(torch.int32)
    else:
        return None

def test_default_wbiol_weight_bit_width_enabled(default_wbiol_layer: QuantWBIOL):
    assert default_wbiol_layer.quant_weight_bit_width() == torch.tensor(8.)

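# Hedged sketch of the fixture the tests above rely on; the concrete module
# and its configuration are assumptions, not the project's actual conftest.
# Any QuantWBIOL subclass with default quantizers (8-bit weights, no bias
# quantization) would satisfy these assertions.
import pytest
from brevitas.nn import QuantLinear


@pytest.fixture
def default_wbiol_layer() -> QuantLinear:
    return QuantLinear(10, 20, bias=True)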