Example #1
import torch

from brevitas.nn import QuantHardTanh  # import path assumed for the Brevitas version these examples target


def thresholds(module: QuantHardTanh, extend_tensor_to_channels=True):
    bit_width = int(module.quant_act_bit_width().item())
    if bit_width != 1:
        if module.is_quant_act_narrow_range:
            # assuming narrow range, symmetric quantization around zero
            # when using narrow range, we represent one element less
            num_distinct_values = 2 ** bit_width - 1
        else:
            num_distinct_values = 2 ** bit_width
        num_thresholds = num_distinct_values - 1
        flat_scale = module.quant_act_scale().view(-1)
        num_scale_channels = flat_scale.shape[0]
        step = torch.abs(flat_scale)
        half_step = step / 2.0
        thresholds = torch.empty(num_scale_channels, num_thresholds)
        # compute the value of the smallest threshold; we'll neg-bias all
        # generated thresholds by this much
        min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
        if not module.is_quant_act_narrow_range:
            min_threshold -= step
        for c in range(num_scale_channels):
            for t in range(num_thresholds):
                thresholds[c][t] = min_threshold[c] + step[c] * t
        if extend_tensor_to_channels:
            output_channels = module._cached_inp.shape[1]
            final_shape = (output_channels, num_thresholds)
            if thresholds.shape != final_shape:
                thresholds = thresholds.expand(final_shape)
        return thresholds
    else:
        # bipolar (1-bit) case: a single threshold at zero
        thresholds = torch.empty([1, 1])
        thresholds[0] = 0
        return thresholds
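To make the threshold layout concrete, here is a small hand-computed sketch; the values (2-bit, narrow range, per-tensor scale 0.5) are illustrative assumptions, not taken from the example above.

import torch

# Hedged worked example: replicate the arithmetic of thresholds() for
# bit_width = 2, narrow range, per-tensor scale 0.5 (assumed values).
step = torch.tensor([0.5])                 # |scale|
half_step = step / 2.0
num_thresholds = (2 ** 2 - 1) - 1          # 3 representable levels -> 2 thresholds
min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
print([float(min_threshold + step * t) for t in range(num_thresholds)])
# expected output: [-0.25, 0.25]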
Example #2
def quant_act_scale(module: QuantHardTanh):
    bit_width = int(module.quant_act_bit_width().item())
    quant_act_scale = module.quant_act_scale().type(torch.FloatTensor).detach()
    if bit_width != 1:
        return quant_act_scale
    else:
        assert quant_act_scale.view(-1).shape[0] == 1, "Unsupported BIPOLAR per channel scale"
        assert quant_act_scale.flatten().item() == 1.0, "Unsupported BIPOLAR scale != 1"
        # bipolar outputs are +/- scale, so the spacing between the two
        # representable levels is twice the scale
        return quant_act_scale * 2
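As a quick sanity check on the 1-bit branch (an illustration with assumed values, not part of the export code): a BIPOLAR activation with scale 1.0 emits the two levels -1 and +1, so the spacing between its representable values is 2.0, which is what quant_act_scale() returns after the final multiplication.

import torch

# Illustration only: the two levels of a bipolar quantizer with scale 1.0
levels = torch.tensor([-1.0, 1.0])
print(levels[1] - levels[0])  # tensor(2.), matching quant_act_scale * 2 above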
import os

import numpy as np
import torch

# Brevitas / FINN imports assumed by this test (paths correspond to the
# library versions these examples appear to target):
import brevitas.onnx as bo
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.nn import QuantHardTanh

import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes

export_onnx_path = "test_act.onnx"  # output file name is an assumption


# presumably parametrized via pytest over abits, narrow_range, min_val,
# max_val and scaling_impl_type in the original test suite
def test_brevitas_act_export_qhardtanh_scaled(abits, narrow_range, min_val,
                                              max_val, scaling_impl_type):
    def get_quant_type(bit_width):
        if bit_width is None:
            return QuantType.FP
        elif bit_width == 1:
            return QuantType.BINARY
        else:
            return QuantType.INT

    act_quant_type = get_quant_type(abits)
    ishape = (1, 15)
    b_act = QuantHardTanh(
        bit_width=abits,
        quant_type=act_quant_type,
        max_val=max_val,
        min_val=min_val,
        restrict_scaling_type=RestrictValueType.LOG_FP,
        scaling_impl_type=scaling_impl_type,
        narrow_range=narrow_range,
    )
    if scaling_impl_type == ScalingImplType.PARAMETER:
        checkpoint = {
            "act_quant_proxy.fused_activation_quant_proxy.\
tensor_quant.scaling_impl.learned_value":
            torch.tensor(0.49).type(torch.FloatTensor)
        }
        b_act.load_state_dict(checkpoint)

    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    inp_tensor = np.random.uniform(low=min_val, high=max_val,
                                   size=ishape).astype(np.float32)
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    inp_tensor = torch.from_numpy(inp_tensor).float()
    b_act.eval()
    expected = b_act.forward(inp_tensor).detach().numpy()
    if not np.isclose(produced, expected, atol=1e-3).all():
        print(
            "abits: ",
            abits,
            " | narrow_range: ",
            narrow_range,
            " | min_val: ",
            min_val,
            " | max_val: ",
            max_val,
        )
        print("layer scale: ",
              b_act.quant_act_scale().type(torch.FloatTensor).detach())
        print("export scale: ", b_act.export_act_scale)
        if abits < 5:
            print(
                "thres:",
                ", ".join(["{:8.4f}".format(x)
                           for x in b_act.export_thres[0]]),
            )
        print("input:",
              ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))

    assert np.isclose(produced, expected, atol=1e-3).all()
    os.remove(export_onnx_path)
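For completeness, the test above can also be driven directly with concrete arguments; the values below are illustrative placeholders (the original suite presumably supplies them via pytest parametrization), not values taken from this document.

# Hypothetical direct invocation with assumed parameter values:
test_brevitas_act_export_qhardtanh_scaled(
    abits=4,
    narrow_range=True,
    min_val=-1.0,
    max_val=1.0,
    scaling_impl_type=ScalingImplType.CONST,
)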