def quant_type(
        module: QuantReLU,
        supported_int_bit_width_range: Tuple[int, ...] = (2, 33)):
    """Return the FINN datatype string (e.g. ``"UINT8"``) for a QuantReLU.

    The activation bit width is read from the module and validated against
    the half-open range ``range(*supported_int_bit_width_range)``.

    Raises:
        RuntimeError: if the bit width is outside the supported range.
    """
    width = int(module.quant_act_bit_width().item())
    # Guard clause: reject unsupported widths up front.
    if width not in range(*supported_int_bit_width_range):
        raise RuntimeError(f"Unsupported input bit width {width} for export")
    return f"UINT{width}"
def quant_type(
        module: QuantReLU,
        supported_bit_width: Tuple[int, ...] = (2, 4, 8, 16, 32)):
    """Return the FINN datatype string (e.g. ``"UINT8"``) for a QuantReLU.

    The activation bit width is read from the module and must be one of the
    explicitly enumerated ``supported_bit_width`` values.

    Raises:
        RuntimeError: if the bit width is not in ``supported_bit_width``.
    """
    bit_width = int(module.quant_act_bit_width().item())
    # Membership test works on the tuple directly — the original wrapped it
    # in list(...) for no benefit.
    if bit_width in supported_bit_width:
        return f"UINT{bit_width}"
    else:
        raise RuntimeError(f"Unsupported input bit width {bit_width} for export")
def thresholds(module: QuantReLU):
    """Compute the per-channel FINN threshold matrix for a QuantReLU.

    For an unsigned quantizer with bit width ``b`` there are ``2**b - 1``
    thresholds per scale channel; channel ``c`` gets the arithmetic sequence
    ``step[c]/2 + step[c]*t`` for ``t = 0 .. num_thresholds-1``, where
    ``step[c]`` is the absolute activation scale.

    Returns:
        Tensor of shape ``(num_scale_channels, num_thresholds)``.
    """
    num_distinct_values = 2 ** int(module.quant_act_bit_width().item())
    num_thresholds = num_distinct_values - 1
    flat_scale = module.quant_act_scale().view(-1)
    step = torch.abs(flat_scale)
    min_threshold = step / 2
    # Vectorized replacement for the element-by-element double loop:
    # (C, 1) + (C, 1) * (T,) broadcasts to (C, T).
    offsets = torch.arange(num_thresholds, dtype=step.dtype)
    return min_threshold.unsqueeze(1) + step.unsqueeze(1) * offsets
def thresholds(module: QuantReLU, extend_tensor_to_channels=True):
    """Compute the per-channel FINN threshold matrix for a QuantReLU.

    For an unsigned quantizer with bit width ``b`` there are ``2**b - 1``
    thresholds per scale channel; channel ``c`` gets the arithmetic sequence
    ``step[c]/2 + step[c]*t``, where ``step[c]`` is the absolute activation
    scale.

    Args:
        module: quantized activation whose scale/bit-width are read.
        extend_tensor_to_channels: when True, broadcast a per-tensor (single
            channel) threshold row to one row per channel of the cached
            input. NOTE(review): channel count is taken from
            ``module._cached_inp.shape[1]`` — presumably an NCHW activation
            cached during a forward pass; verify a forward has run first.

    Returns:
        Tensor of shape ``(channels, num_thresholds)``.
    """
    num_distinct_values = 2 ** int(module.quant_act_bit_width().item())
    num_thresholds = num_distinct_values - 1
    flat_scale = module.quant_act_scale().view(-1)
    step = torch.abs(flat_scale)
    min_threshold = step / 2
    # Vectorized replacement for the element-by-element double loop:
    # (C, 1) + (C, 1) * (T,) broadcasts to (C, T).
    offsets = torch.arange(num_thresholds, dtype=step.dtype)
    thresholds = min_threshold.unsqueeze(1) + step.unsqueeze(1) * offsets
    if extend_tensor_to_channels:
        output_channels = module._cached_inp.shape[1]
        final_shape = (output_channels, num_thresholds)
        if thresholds.shape != final_shape:
            # expand() is a zero-copy broadcast; only valid when the scale
            # is per-tensor (single row) or already matches.
            thresholds = thresholds.expand(final_shape)
    return thresholds
def quant_type(module: QuantReLU):
    """Map a QuantReLU's activation bit width and signedness to a FINN datatype.

    Delegates the actual datatype construction to ``finn_datatype``.
    """
    return finn_datatype(
        module.quant_act_bit_width(),
        module.is_quant_act_signed,
    )