Code Example #1
File: test_act_scaling.py  Project: uslumt/brevitas
    def test_scaling_stats_to_parameter(self):

        stats_act = QuantReLU(bit_width=BIT_WIDTH,
                              max_val=MAX_VAL,
                              quant_type=QuantType.INT,
                              scaling_impl_type=ScalingImplType.STATS,
                              scaling_stats_permute_dims=None,
                              scaling_stats_op=StatsOp.MAX)
        stats_act.train()
        for i in range(RANDOM_ITERS):
            inp = torch.randn([8, 3, 64, 64])
            stats_act(inp)

        stats_state_dict = stats_act.state_dict()

        param_act = QuantReLU(bit_width=BIT_WIDTH,
                              max_val=MAX_VAL,
                              quant_type=QuantType.INT,
                              scaling_impl_type=ScalingImplType.PARAMETER)
        param_act.load_state_dict(stats_state_dict)

        stats_act.eval()
        param_act.eval()

        assert (torch.allclose(stats_act.quant_act_scale(),
                               param_act.quant_act_scale()))
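
The test above relies on module-level constants and imports that the excerpt omits. A minimal, hypothetical preamble, assuming the pre-0.3 brevitas API used here, might look like the following (the constant values are illustrative, not taken from the original file):

# Hypothetical preamble for the test above; values are illustrative.
import torch
from brevitas.nn import QuantReLU
from brevitas.core.quant import QuantType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.stats import StatsOp

BIT_WIDTH = 4        # illustrative activation bit width
MAX_VAL = 6.0        # illustrative clipping maximum
RANDOM_ITERS = 32    # illustrative number of stats-collection passes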
Code Example #2
def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type,
                                  QONNX_export):
    min_val = -1.0
    ishape = (1, 15)

    b_act = QuantReLU(
        bit_width=abits,
        max_val=max_val,
        scaling_impl_type=scaling_impl_type,
        restrict_scaling_type=RestrictValueType.LOG_FP,
        quant_type=QuantType.INT,
    )
    if scaling_impl_type == ScalingImplType.PARAMETER:
        checkpoint = {
            "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
scaling_impl.learned_value":
            torch.tensor(0.49).type(torch.FloatTensor)
        }
        b_act.load_state_dict(checkpoint)
    if QONNX_export:
        m_path = export_onnx_path
        BrevitasONNXManager.export(b_act, ishape, m_path)
        qonnx_cleanup(m_path, out_file=m_path)
        model = ModelWrapper(m_path)
        model = model.transform(ConvertQONNXtoFINN())
        model.save(m_path)
    else:
        bo.export_finn_onnx(b_act, ishape, export_onnx_path)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    inp_tensor = np.random.uniform(low=min_val, high=max_val,
                                   size=ishape).astype(np.float32)
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    inp_tensor = torch.from_numpy(inp_tensor).float()
    b_act.eval()
    expected = b_act.forward(inp_tensor).detach().numpy()
    if not np.isclose(produced, expected, atol=1e-3).all():
        print(abits, max_val, scaling_impl_type)
        print("scale: ",
              b_act.quant_act_scale().type(torch.FloatTensor).detach())
        if abits < 5:
            print(
                "thres:",
                ", ".join(["{:8.4f}".format(x)
                           for x in b_act.export_thres[0]]),
            )
        print("input:",
              ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))

    assert np.isclose(produced, expected, atol=1e-3).all()
    os.remove(export_onnx_path)
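
This export test also depends on names defined elsewhere in its test module. A hedged sketch of the imports and export path it presumably assumes (module locations follow older FINN/QONNX releases and may differ between versions; the path value is illustrative):

# Hypothetical import block for the export test above; module paths are
# assumptions based on older FINN/QONNX releases and may need adjusting.
import os
import numpy as np
import torch
import brevitas.onnx as bo
from brevitas.nn import QuantReLU
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.export.onnx.generic.manager import BrevitasONNXManager
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from finn.core.modelwrapper import ModelWrapper
import finn.core.onnx_exec as oxe
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

export_onnx_path = "test_brevitas_act_export.onnx"  # illustrative file name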
Code Example #3
def test_brevitas_act_export_relu_imagenet(abits, max_val,
                                           scaling_per_channel):
    out_channels = 32
    ishape = (1, out_channels, 1, 1)
    min_val = -1.0
    b_act = QuantReLU(
        bit_width=abits,
        quant_type=QuantType.INT,
        scaling_impl_type=ScalingImplType.PARAMETER,
        scaling_per_channel=scaling_per_channel,
        restrict_scaling_type=RestrictValueType.LOG_FP,
        scaling_min_val=2e-16,
        max_val=6.0,
        return_quant_tensor=True,
        per_channel_broadcastable_shape=(1, out_channels, 1, 1),
    )
    if scaling_per_channel is True:
        rand_tensor = (2) * torch.rand((1, out_channels, 1, 1))
    else:
        rand_tensor = torch.tensor(1.2398)
    checkpoint = {
        "act_quant_proxy.fused_activation_quant_proxy.tensor_quant.\
scaling_impl.learned_value":
        rand_tensor.type(torch.FloatTensor)
    }
    b_act.load_state_dict(checkpoint)
    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    inp_tensor = np.random.uniform(low=min_val, high=max_val,
                                   size=ishape).astype(np.float32)
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    inp_tensor = torch.from_numpy(inp_tensor).float()
    b_act.eval()
    expected = b_act.forward(inp_tensor).tensor.detach().numpy()
    if not np.isclose(produced, expected, atol=1e-3).all():
        print(abits, max_val)
        print("scale: ",
              b_act.quant_act_scale().type(torch.FloatTensor).detach())
        if abits < 5:
            print(
                "thres:",
                ", ".join(["{:8.4f}".format(x)
                           for x in b_act.export_thres[0]]),
            )
        print("input:",
              ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
        print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
        print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))

    assert np.isclose(produced, expected, atol=1e-3).all()
    os.remove(export_onnx_path)
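
The difference from the previous example is the learned scale's shape: with scaling_per_channel=True each of the 32 channels gets its own scale, otherwise a single scalar is shared. A standalone, hypothetical illustration of the two shapes loaded into learned_value above:

# Hypothetical shape illustration for the per-channel vs. per-tensor scale.
import torch

out_channels = 32
per_channel_scale = 2 * torch.rand((1, out_channels, 1, 1))  # one scale per channel
per_tensor_scale = torch.tensor(1.2398)                      # single shared scale

assert per_channel_scale.shape == (1, out_channels, 1, 1)
assert per_tensor_scale.dim() == 0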
Code Example #4
File: act.py  Project: marenan/brevitas
def thresholds(module: QuantReLU):
    # One threshold per representable output level (minus zero): a b-bit
    # unsigned activation has 2**b distinct values and 2**b - 1 thresholds.
    num_distinct_values = 2 ** int(module.quant_act_bit_width().item())
    num_thresholds = num_distinct_values - 1
    # One scale (and therefore one threshold row) per scale channel.
    flat_scale = module.quant_act_scale().view(-1)
    num_scale_channels = flat_scale.shape[0]
    step = torch.abs(flat_scale)
    min_threshold = step / 2
    # Thresholds are evenly spaced: min_threshold + step * t for level t.
    thresholds = torch.empty(num_scale_channels, num_thresholds)
    for c in range(num_scale_channels):
        for t in range(num_thresholds):
            thresholds[c][t] = min_threshold[c] + step[c] * t
    return thresholds
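
For intuition, with a 2-bit activation and a single scale of 0.5 the helper above yields three evenly spaced thresholds: 0.25, 0.75 and 1.25. A standalone, hypothetical rendering of that arithmetic:

# Standalone illustration of the threshold arithmetic (hypothetical numbers,
# not tied to a real QuantReLU instance).
import torch

bit_width = 2
num_distinct_values = 2 ** bit_width          # 4 representable levels
num_thresholds = num_distinct_values - 1      # 3 thresholds
scale = torch.tensor([0.5])                   # one scale channel
step = torch.abs(scale)
min_threshold = step / 2
thresholds = min_threshold + step * torch.arange(num_thresholds).float()
print(thresholds)  # tensor([0.2500, 0.7500, 1.2500])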
Code Example #5
def thresholds(module: QuantReLU, extend_tensor_to_channels=True):
    num_distinct_values = 2 ** int(module.quant_act_bit_width().item())
    num_thresholds = num_distinct_values - 1
    flat_scale = module.quant_act_scale().view(-1)
    num_scale_channels = flat_scale.shape[0]
    step = torch.abs(flat_scale)
    min_threshold = step / 2
    thresholds = torch.empty(num_scale_channels, num_thresholds)
    for c in range(num_scale_channels):
        for t in range(num_thresholds):
            thresholds[c][t] = min_threshold[c] + step[c] * t
    if extend_tensor_to_channels:
        # With a per-tensor scale there is a single threshold row; broadcast it
        # across the channel dimension of the cached input so every output
        # channel receives an identical copy.
        output_channels = module._cached_inp.shape[1]
        final_shape = (output_channels, num_thresholds)
        if thresholds.shape != final_shape:
            thresholds = thresholds.expand(final_shape)
    return thresholds
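
The extend_tensor_to_channels branch only matters for per-tensor scaling, where a single threshold row has to be replicated across every channel of the cached input. A small, hypothetical illustration of that broadcast:

# Hypothetical illustration of the channel broadcast performed above.
import torch

thresholds = torch.tensor([[0.25, 0.75, 1.25]])   # shape (1, 3): per-tensor scale
output_channels = 4                                # e.g. taken from the cached input
expanded = thresholds.expand(output_channels, 3)   # shape (4, 3), no data copied
print(expanded.shape)                              # torch.Size([4, 3])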
Code Example #6
def quant_act_scale(module: QuantReLU):
    quant_act_scale = module.quant_act_scale().type(torch.FloatTensor).detach()
    return quant_act_scale
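
A minimal, hypothetical call site for this wrapper, assuming the helper above is in scope and the pre-0.3 brevitas constructor arguments used in the earlier examples:

# Hypothetical usage of the quant_act_scale helper; constructor arguments
# are illustrative.
import torch
from brevitas.nn import QuantReLU
from brevitas.core.quant import QuantType

act = QuantReLU(bit_width=4, max_val=6.0, quant_type=QuantType.INT)
act.eval()
print(quant_act_scale(act))  # detached FloatTensor holding the activation scale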