# Imports and module-level constants used by these tests. The module paths match
# the pre-QONNX FINN releases this code targets, and the constant values are
# best-effort assumptions; align both with the rest of the test suite.
import os

import numpy as np
import pytest
import torch

from brevitas.core.quant import QuantType
from brevitas.export import FINNManager
from brevitas.nn import QuantAvgPool2d
from brevitas.quant_tensor import QuantTensor
from brevitas_examples.bnn_pynq.models import model_with_cfg

import finn.core.onnx_exec as oxe
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.double_to_single_float import DoubleToSingleFloat
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.general import GiveUniqueNodeNames, RemoveStaticGraphInputs
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import gen_finn_dt_tensor

MAX_WBITS = 2  # assumed: largest weight bit width in the model grid
MAX_ABITS = 2  # assumed: largest activation bit width in the model grid
MIN_INP_VAL = 0  # assumed range for the random test vectors
MAX_INP_VAL = 255
ATOL = 1e-3  # assumed absolute tolerance for the FINN-vs-Brevitas comparison
FC_INPUT_SIZE = (1, 1, 28, 28)  # assumed: MNIST-shaped input of the FC models
QUARTZNET_POSTPROCESSED_INPUT_SIZE = (1, 64, 256)  # assumed (batch, features, frames)
export_onnx_path = "test_brevitas_avg_pool_export.onnx"  # assumed scratch file name


# Parametrization grids are assumptions inferred from the skip logic below.
@pytest.mark.parametrize("size", ["TFC", "SFC", "LFC"])
@pytest.mark.parametrize("wbits", [1, 2])
@pytest.mark.parametrize("abits", [1, 2])
@pytest.mark.parametrize("pretrained", [True])
def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
    if size == "LFC" and wbits == 2 and abits == 2:
        pytest.skip(f"No LFC_{MAX_WBITS}W{MAX_ABITS}A present.")
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    nname = f"{size}_{wbits}W{abits}A"
    finn_onnx = nname + ".onnx"
    fc, _ = model_with_cfg(nname.lower(), pretrained=pretrained)
    FINNManager.export_onnx(fc, FC_INPUT_SIZE, finn_onnx)
    model = ModelWrapper(finn_onnx)
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(DoubleToSingleFloat())
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    # generate a random test vector
    input_tensor = np.random.uniform(
        MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32)
    # run using FINN-based execution
    input_dict = {"0": input_tensor}
    output_dict = oxe.execute_onnx(model, input_dict)
    produced = output_dict[list(output_dict.keys())[0]]
    # run the same input through PyTorch/Brevitas
    input_tensor = torch.from_numpy(input_tensor).float()
    expected = fc.forward(input_tensor).detach().numpy()
    assert np.isclose(produced, expected, atol=ATOL).all()
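
# All three tests in this file run the same post-export cleanup pipeline on the
# exported model. A small helper capturing that sequence (hypothetical; the
# tests themselves inline these calls):
def _prepare_finn_model(onnx_path):
    model = ModelWrapper(onnx_path)
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(DoubleToSingleFloat())
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    return model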

@pytest.mark.parametrize("pretrained", [True])  # assumed: only the pretrained checkpoint is meaningful here
def test_quartznet_asr_4b(pretrained):
    # inline import to make xfail work on the import error (see the sketch below)
    from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b
    finn_onnx = "quant_quartznet_perchannelscaling_4b.onnx"
    quartznet = quant_quartznet_perchannelscaling_4b(pretrained, export_mode=True)
    FINNManager.export_onnx(quartznet, QUARTZNET_POSTPROCESSED_INPUT_SIZE, finn_onnx)
    model = ModelWrapper(finn_onnx)
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(DoubleToSingleFloat())
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    # generate a random test vector
    input_tensor = np.random.uniform(
        MIN_INP_VAL, MAX_INP_VAL,
        size=QUARTZNET_POSTPROCESSED_INPUT_SIZE).astype(np.float32)
    # run using FINN-based execution
    input_dict = {"0": input_tensor}
    output_dict = oxe.execute_onnx(model, input_dict)
    produced = output_dict[list(output_dict.keys())[0]]
    # run the same input through PyTorch/Brevitas
    input_tensor = torch.from_numpy(input_tensor).float()
    expected = quartznet(input_tensor).detach().numpy()
    assert np.isclose(produced, expected, atol=ATOL).all()
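
# The inline import in the test above ensures that a missing speech_to_text
# extra raises ImportError inside the test body, where pytest can turn it into
# an expected failure instead of a collection error. A minimal sketch of that
# pattern on a hypothetical test (the exact marker used by the real suite is an
# assumption):
@pytest.mark.xfail(raises=ImportError, reason="speech_to_text extras may be missing")
def test_inline_import_xfail_sketch():
    from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b  # noqa: F401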

# Parametrization grids are assumptions; align them with the real suite.
@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("signed", [False, True])
@pytest.mark.parametrize("bit_width", [2, 4])
@pytest.mark.parametrize("input_bit_width", [4, 8, 16])
@pytest.mark.parametrize("channels", [2, 4])
@pytest.mark.parametrize("idim", [7, 8])
def test_brevitas_avg_pool_export(
        kernel_size, stride, signed, bit_width, input_bit_width, channels, idim):
    ishape = (1, channels, idim, idim)
    ibw_tensor = torch.Tensor([input_bit_width])
    b_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        signed=signed,
        min_overall_bit_width=bit_width,
        max_overall_bit_width=bit_width,
        quant_type=QuantType.INT)
    # call the forward pass manually once to cache scale factor and bit width
    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
    scale = np.ones((1, channels, 1, 1))
    output_scale = torch.from_numpy(scale).float()
    input_quant_tensor = QuantTensor(input_tensor, output_scale, ibw_tensor, signed)
    FINNManager.export_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    # determine the FINN datatype of the input
    prefix = "INT" if signed else "UINT"
    dt_name = prefix + str(input_bit_width // 2)
    dtype = DataType[dt_name]
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # execution with an integer input tensor and scale = 1:
    # calculate the golden output with Brevitas
    inp = gen_finn_dt_tensor(dtype, ishape)
    input_tensor = torch.from_numpy(inp).float()
    input_quant_tensor = QuantTensor(input_tensor, output_scale, ibw_tensor, signed)
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # FINN execution
    idict = {model.graph.input[0].name: inp}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert (expected == produced).all()

    # execution with a float input tensor and scale != 1 (see the invariant
    # sketch after this test)
    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32)
    inp_tensor = inp * scale
    input_tensor = torch.from_numpy(inp_tensor).float()
    input_scale = torch.from_numpy(scale).float()
    input_quant_tensor = QuantTensor(input_tensor, input_scale, ibw_tensor, signed)
    # export again so the new scale values are baked into the graph
    FINNManager.export_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # FINN execution
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert np.isclose(expected, produced).all()
    os.remove(export_onnx_path)
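
# The two checks in the test above are two views of one invariant: with
# scale = 1 the FINN graph consumes the raw integer tensor directly, while with
# a per-channel scale the float input is just that integer tensor rescaled.
# A minimal sketch of the relationship (names are local to this example and
# INT4 is chosen arbitrarily):
def _rescale_invariant_sketch():
    channels = 2
    int_inp = gen_finn_dt_tensor(DataType["INT4"], (1, channels, 7, 7))
    scale = np.random.uniform(low=0.1, high=1, size=(1, channels, 1, 1)).astype(np.float32)
    float_inp = int_inp * scale
    # dividing the scaled input by its per-channel scale recovers the integers
    assert np.allclose(float_inp / scale, int_inp)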