# Example 1
def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
    """Export an FC Brevitas model to FINN-ONNX and verify that FINN-based
    execution of the exported graph matches the PyTorch/Brevitas forward pass.
    """
    if size == "LFC" and wbits == 2 and abits == 2:
        pytest.skip(f"No LFC_{MAX_WBITS}W{MAX_ABITS}A present.")
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    cfg_name = f"{size}_{wbits}W{abits}A"
    onnx_path = cfg_name + ".onnx"
    fc, _ = model_with_cfg(cfg_name.lower(), pretrained=pretrained)
    FINNManager.export_onnx(fc, FC_INPUT_SIZE, onnx_path)
    model = ModelWrapper(onnx_path)
    # Graph cleanup transforms, applied in the same order as before.
    for transform in (GiveUniqueNodeNames(),
                      DoubleToSingleFloat(),
                      InferShapes(),
                      FoldConstants(),
                      RemoveStaticGraphInputs()):
        model = model.transform(transform)
    # Random float test vector.
    rand_input = np.random.uniform(
        MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32)
    # FINN-based execution of the exported model.
    output_dict = oxe.execute_onnx(model, {"0": rand_input})
    produced = next(iter(output_dict.values()))
    # Reference: forward pass in PyTorch/Brevitas.
    expected = fc.forward(torch.from_numpy(rand_input).float()).detach().numpy()
    assert np.isclose(produced, expected, atol=ATOL).all()
def test_brevitas_cnv_onnx_export_and_exec(wbits, abits, pretrained):
    """Export a CNV Brevitas model to FINN-ONNX and verify that FINN-based
    execution of the exported graph matches the PyTorch/Brevitas forward pass.
    """
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    cfg_name = f"CNV_{wbits}W{abits}A"
    onnx_path = cfg_name + ".onnx"
    cnv, _ = model_with_cfg(cfg_name.lower(), pretrained=pretrained)
    cnv.eval()
    # Random integer test vector; the scaled copy feeds the PyTorch model.
    raw_input = np.random.randint(
        MIN_INP_VAL, MAX_INP_VAL, size=CNV_INPUT_SIZE).astype(np.float32)
    scale = 1. / 255
    scaled_input = torch.from_numpy(raw_input * scale)
    quant_input = QuantTensor(
        scaled_input,
        scale=torch.tensor(scale),
        bit_width=torch.tensor(8.0),
        signed=False)
    FINNManager.export(cnv, export_path=onnx_path, input_t=quant_input)
    model = ModelWrapper(onnx_path)
    # Graph cleanup transforms, applied in the same order as before.
    for transform in (GiveUniqueNodeNames(),
                      DoubleToSingleFloat(),
                      InferShapes(),
                      FoldConstants(),
                      RemoveStaticGraphInputs()):
        model = model.transform(transform)
    # FINN-based execution receives the raw (unscaled) input array.
    output_dict = oxe.execute_onnx(model, {"0": raw_input})
    produced = next(iter(output_dict.values()))
    # Reference: forward pass in PyTorch/Brevitas on the scaled input.
    expected = cnv(scaled_input).detach().numpy()
    assert np.isclose(produced, expected, atol=ATOL).all()
# Example 3
def test_debug_finn_onnx_export():
    """Export the reference model to FINN-ONNX with debug enabled and check
    that the debug hook captured values during a forward pass."""
    model, cfg = model_with_cfg(REF_MODEL, pretrained=False)
    debug_hook = enable_debug(model)
    dummy_input = torch.randn(1, 3, 32, 32)
    export_finn_onnx(
        model, input_shape=dummy_input.shape, export_path='debug.onnx')
    model(dummy_input)
    assert debug_hook.values
# Example 4
def test_brevitas_cnv_jit_trace(wbits, abits):
    """JIT-trace a CNV model via the patched trace helper and check the
    traced module's output matches eager execution."""
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    cfg_name = f"CNV_{wbits}W{abits}A"
    cnv, _ = model_with_cfg(cfg_name.lower(), pretrained=False)
    cnv.train(False)  # inference mode
    sample = torch.randn(CNV_INPUT_SIZE)
    traced = jit_trace_patched(cnv, sample)
    traced_out = traced(sample)
    eager_out = cnv(sample)
    assert eager_out.isclose(traced_out).all().item()
# Example 5
def test_brevitas_fc_jit_trace(size, wbits, abits):
    """JIT-trace an FC model via the patched trace helper and check the
    traced module's output matches eager execution."""
    if size == "LFC" and wbits == 2 and abits == 2:
        pytest.skip(f"No LFC_{MAX_WBITS}W{MAX_ABITS}A present.")
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    cfg_name = f"{size}_{wbits}W{abits}A"
    fc, _ = model_with_cfg(cfg_name.lower(), pretrained=False)
    fc.train(False)  # inference mode
    sample = torch.randn(FC_INPUT_SIZE)
    traced = jit_trace_patched(fc, sample)
    traced_out = traced(sample)
    eager_out = fc(sample)
    assert eager_out.isclose(traced_out).all().item()
# Example 6
def test_brevitas_cnv_jit_trace(wbits, abits):
    """JIT-trace a CNV model with the JIT patches applied and check the
    traced module's output matches eager execution.

    Bug fix: the patches from ``jit_patches_generator`` must still be
    active while ``torch.jit.trace`` runs. The original code dedented the
    trace call out of the ``with ExitStack()`` block, so every patch was
    reverted before tracing started, defeating the purpose of the stack.
    """
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    nname = f"CNV_{wbits}W{abits}A"
    cnv, _ = model_with_cfg(nname.lower(), pretrained=False)
    cnv.train(False)  # inference mode
    input_tensor = torch.randn(CNV_INPUT_SIZE)
    with ExitStack() as stack:
        for mgr in jit_patches_generator():
            stack.enter_context(mgr)
        # Trace while the patches are still in effect.
        traced_model = torch.jit.trace(cnv, input_tensor)
    out_traced = traced_model(input_tensor)
    out = cnv(input_tensor)
    assert out.isclose(out_traced).all().item()