def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
    if size == "LFC" and wbits == MAX_WBITS and abits == MAX_ABITS:
        pytest.skip(f"No LFC_{MAX_WBITS}W{MAX_ABITS}A present.")
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    nname = f"{size}_{wbits}W{abits}A"
    finn_onnx = nname + ".onnx"
    fc, _ = model_with_cfg(nname.lower(), pretrained=pretrained)
    FINNManager.export_onnx(fc, FC_INPUT_SIZE, finn_onnx)
    model = ModelWrapper(finn_onnx)
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(DoubleToSingleFloat())
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    # load a random test vector
    input_tensor = np.random.uniform(MIN_INP_VAL,
                                     MAX_INP_VAL,
                                     size=FC_INPUT_SIZE).astype(np.float32)
    # run using FINN-based execution
    input_dict = {"0": input_tensor}  # the exported graph's input tensor is named "0"
    output_dict = oxe.execute_onnx(model, input_dict)
    produced = next(iter(output_dict.values()))
    # run using PyTorch/Brevitas
    input_tensor = torch.from_numpy(input_tensor).float()
    # do forward pass in PyTorch/Brevitas
    expected = fc.forward(input_tensor).detach().numpy()
    assert np.isclose(produced, expected, atol=ATOL).all()
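Note: these snippets are excerpts from a larger test module, so their imports and module-level constants are not shown. A minimal sketch of what they plausibly assume is below; the module paths follow older FINN/Brevitas releases (newer FINN moved several of these helpers into the qonnx package), and the constant values are illustrative assumptions rather than the originals.

import os

import numpy as np
import pytest
import torch

import brevitas.onnx as bo
import finn.core.onnx_exec as oxe
from brevitas.core.quant import QuantType           # path varies by Brevitas version
from brevitas.export import FINNManager
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import QuantTensor
from brevitas_examples.bnn_pynq.models import model_with_cfg  # path varies by version
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.double_to_single_float import DoubleToSingleFloat
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import gen_finn_dt_tensor

# illustrative values -- the real test module defines its own
MAX_WBITS = 2
MAX_ABITS = 2
FC_INPUT_SIZE = (1, 1, 28, 28)                      # MNIST-shaped input for the FC models
QUARTZNET_POSTPROCESSED_INPUT_SIZE = (1, 64, 256)   # assumed (batch, features, time) shape
MIN_INP_VAL = 0
MAX_INP_VAL = 255
ATOL = 1e-3
export_onnx_path = "test_brevitas_avg_pool_export.onnx"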
Example #2
def test_quartznet_asr_4b(pretrained):
    # inline import so the xfail marker can catch the import error
    from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b

    finn_onnx = "quant_quartznet_perchannelscaling_4b.onnx"
    quartznet = quant_quartznet_perchannelscaling_4b(pretrained, export_mode=True)
    FINNManager.export_onnx(quartznet, QUARTZNET_POSTPROCESSED_INPUT_SIZE, finn_onnx)
    model = ModelWrapper(finn_onnx)
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(DoubleToSingleFloat())
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    # load a random test vector
    input_tensor = np.random.uniform(
        MIN_INP_VAL, MAX_INP_VAL, size=QUARTZNET_POSTPROCESSED_INPUT_SIZE).astype(np.float32)
    # run using FINN-based execution
    input_dict = {"0": input_tensor}
    output_dict = oxe.execute_onnx(model, input_dict)
    produced = next(iter(output_dict.values()))
    # run using PyTorch/Brevitas
    input_tensor = torch.from_numpy(input_tensor).float()
    # do forward pass in PyTorch/Brevitas
    expected = quartznet(input_tensor).detach().numpy()
    assert np.isclose(produced, expected, atol=ATOL).all()
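Examples #1 and #2 follow the same export / transform / execute / compare pattern. Factored out, that pattern might look like the sketch below; assert_finn_matches_brevitas and its argument names are invented here for illustration and are not part of either test suite.

def assert_finn_matches_brevitas(module, input_size, onnx_path, atol=ATOL):
    # hypothetical helper: export `module` to FINN-ONNX, run both sides on
    # the same random input, and compare the outputs
    FINNManager.export_onnx(module, input_size, onnx_path)
    model = ModelWrapper(onnx_path)
    for xform in (GiveUniqueNodeNames(), DoubleToSingleFloat(),
                  InferShapes(), FoldConstants(), RemoveStaticGraphInputs()):
        model = model.transform(xform)
    inp = np.random.uniform(MIN_INP_VAL, MAX_INP_VAL,
                            size=input_size).astype(np.float32)
    produced = next(iter(oxe.execute_onnx(model, {"0": inp}).values()))
    expected = module(torch.from_numpy(inp).float()).detach().numpy()
    assert np.isclose(produced, expected, atol=atol).all()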
Example #3
def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width,
                                  input_bit_width, channels, idim):
    ishape = (1, channels, idim, idim)
    ibw_tensor = torch.Tensor([input_bit_width])

    b_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        signed=signed,
        min_overall_bit_width=bit_width,
        max_overall_bit_width=bit_width,
        quant_type=QuantType.INT,
    )
    # build a zero-valued QuantTensor input so the exporter can pick up
    # the scale factor and bit width from it
    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
    scale = np.ones((1, channels, 1, 1))
    output_scale = torch.from_numpy(scale).float()
    input_quant_tensor = QuantTensor(input_tensor, output_scale, ibw_tensor,
                                     signed)
    FINNManager.export_onnx(b_avgpool,
                            ishape,
                            export_onnx_path,
                            input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)

    # determine input FINN datatype
    prefix = "INT" if signed else "UINT"
    # the test stimulus uses half the declared input bit width
    dt_name = prefix + str(input_bit_width // 2)
    dtype = DataType[dt_name]
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # execution with input tensor using integers and scale = 1
    # calculate golden output
    inp = gen_finn_dt_tensor(dtype, ishape)
    input_tensor = torch.from_numpy(inp).float()
    input_quant_tensor = QuantTensor(input_tensor, output_scale, ibw_tensor,
                                     signed)
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()

    # finn execution
    idict = {model.graph.input[0].name: inp}
    odict = oxe.execute_onnx(model, idict, True)  # True: return the full execution context
    produced = odict[model.graph.output[0].name]
    assert (expected == produced).all()

    # execution with input tensor using float and scale != 1
    scale = np.random.uniform(low=0, high=1,
                              size=(1, channels, 1, 1)).astype(np.float32)
    inp_tensor = inp * scale
    input_tensor = torch.from_numpy(inp_tensor).float()
    input_scale = torch.from_numpy(scale).float()
    input_quant_tensor = QuantTensor(input_tensor, input_scale, ibw_tensor,
                                     signed)
    # export again so the updated scale values are baked into the graph;
    # bo is assumed to be `import brevitas.onnx as bo`, whose export_finn_onnx
    # wraps the same FINN export path as FINNManager.export_onnx
    bo.export_finn_onnx(b_avgpool,
                        ishape,
                        export_onnx_path,
                        input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # finn execution
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]

    # unlike the ATOL-based checks above, this uses np.isclose's defaults
    # (rtol=1e-05, atol=1e-08)
    assert np.isclose(expected, produced).all()

    os.remove(export_onnx_path)
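For reference, the FINN datatype helpers used in Example #3 behave roughly as follows (a sketch against the older finn.core.datatype / finn.util.basic API):

dt = DataType["INT4"]                     # 4-bit signed integer type, range [-8, 7]
x = gen_finn_dt_tensor(dt, (1, 2, 7, 7))  # random tensor with values in dt's range
assert x.min() >= dt.min() and x.max() <= dt.max()
assert (x == np.round(x)).all()           # values are integers stored as floats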