Пример #1
0
 def quant_type(
         module: QuantReLU,
         supported_int_bit_width_range: Tuple[int,...] = (2, 33)):
     """Map the activation bit width of ``module`` to an unsigned type name.

     The bit width is read from ``module.quant_act_bit_width().item()`` and
     converted with ``int()``, so any fractional part is truncated. It must lie
     in ``range(*supported_int_bit_width_range)``. With the default ``(2, 33)``
     that means 2 through 32 inclusive.

     Returns:
         The string ``"UINT<bit_width>"``, e.g. ``"UINT8"``.

     Raises:
         RuntimeError: if the bit width is outside the supported range.
     """
     act_bits = int(module.quant_act_bit_width().item())
     allowed_widths = range(*supported_int_bit_width_range)
     # Guard clause: reject unsupported widths before building the type name.
     if act_bits not in allowed_widths:
         raise RuntimeError(f"Unsupported input bit width {act_bits} for export")
     return f"UINT{act_bits}"
Пример #2
0
 def quant_type(
         module: QuantReLU,
         supported_bit_width: Tuple[int,...] = (2, 4, 8, 16, 32)):
     """Map the activation bit width of ``module`` to an unsigned type name.

     The bit width is read from ``module.quant_act_bit_width().item()`` and
     converted with ``int()``, so any fractional part is truncated. It must be
     one of ``supported_bit_width``. By default these are 2, 4, 8, 16 and 32.

     Returns:
         The string ``"UINT<bit_width>"``, e.g. ``"UINT8"``.

     Raises:
         RuntimeError: if the bit width is not in ``supported_bit_width``.
     """
     bit_width = int(module.quant_act_bit_width().item())
     # Test membership on the container directly. Copying it into a list first
     # (as before) allocated on every call without changing the result.
     if bit_width not in supported_bit_width:
         raise RuntimeError(f"Unsupported input bit width {bit_width} for export")
     return f"UINT{bit_width}"
Пример #3
0
 def thresholds(module: QuantReLU):
     """Build the per-channel threshold table for ``module``'s quantized activation.

     With ``b = int(module.quant_act_bit_width().item())`` there are ``2**b - 1``
     thresholds per scale channel. For scale channel ``c``, with
     ``step_c = |scale_c|``, threshold ``t`` equals ``step_c / 2 + step_c * t``.

     Returns:
         A CPU tensor in the default floating dtype
         (``torch.get_default_dtype()``) of shape
         ``(num_scale_channels, 2**b - 1)``. ``num_scale_channels`` is the
         number of elements of ``module.quant_act_scale()``.
     """
     num_distinct_values = 2 ** int(module.quant_act_bit_width().item())
     num_thresholds = num_distinct_values - 1
     # reshape (not view) so a non-contiguous scale tensor is accepted as well.
     flat_scale = module.quant_act_scale().reshape(-1)
     step = torch.abs(flat_scale)
     min_threshold = step / 2
     # Build the whole (channels, thresholds) table with one broadcast instead
     # of a Python double loop over every element.
     levels = torch.arange(num_thresholds, device=step.device)
     table = min_threshold.unsqueeze(1) + step.unsqueeze(1) * levels
     # Keep the previous output contract: the table used to be allocated with
     # torch.empty(...), i.e. on the CPU in the default dtype.
     return table.to(device="cpu", dtype=torch.get_default_dtype())
Пример #4
0
 def thresholds(module: QuantReLU, extend_tensor_to_channels=True):
     """Build the threshold table for ``module``, optionally one row per channel.

     With ``b = int(module.quant_act_bit_width().item())`` there are ``2**b - 1``
     thresholds per scale channel. For scale channel ``c``, with
     ``step_c = |scale_c|``, threshold ``t`` equals ``step_c / 2 + step_c * t``.
     The table is a CPU tensor in the default floating dtype.

     Args:
         module: the activation module to export.
         extend_tensor_to_channels: when true, broadcast the table to
             ``module._cached_inp.shape[1]`` rows. This assumes
             ``_cached_inp`` holds a tensor with channels on dim 1
             (presumably cached from a forward pass; confirm with caller).

     Returns:
         A tensor of shape ``(num_scale_channels, 2**b - 1)``, or
         ``(output_channels, 2**b - 1)`` when extended. An extended
         per-tensor table is an ``expand`` view (no copy).

     Raises:
         RuntimeError: if extending and the scale has neither one channel nor
             ``output_channels`` channels.
     """
     num_distinct_values = 2 ** int(module.quant_act_bit_width().item())
     num_thresholds = num_distinct_values - 1
     # reshape (not view) so a non-contiguous scale tensor is accepted as well.
     flat_scale = module.quant_act_scale().reshape(-1)
     num_scale_channels = flat_scale.shape[0]
     step = torch.abs(flat_scale)
     min_threshold = step / 2
     # Vectorised construction of the (channels, thresholds) table; the result
     # is moved to the CPU default dtype like the former torch.empty buffer.
     levels = torch.arange(num_thresholds, device=step.device)
     table = (min_threshold.unsqueeze(1) + step.unsqueeze(1) * levels).to(
         device="cpu", dtype=torch.get_default_dtype())
     if extend_tensor_to_channels:
         output_channels = module._cached_inp.shape[1]
         if num_scale_channels != output_channels:
             # Only a per-tensor (single-channel) table can be broadcast.
             if num_scale_channels != 1:
                 raise RuntimeError(
                     f"Cannot extend thresholds with {num_scale_channels} scale "
                     f"channels to {output_channels} output channels")
             table = table.expand(output_channels, num_thresholds)
     return table
Пример #5
0
 def quant_type(module: QuantReLU):
     """Return the datatype that ``finn_datatype`` builds for ``module``'s activation.

     Forwards ``module.quant_act_bit_width()`` (evaluated first) and the
     ``module.is_quant_act_signed`` flag to ``finn_datatype`` unchanged.
     """
     return finn_datatype(module.quant_act_bit_width(), module.is_quant_act_signed)