def test_brevitas_qlinear(bias, out_features, in_features, w_bits, i_dtype):
    """Check a FINN-ONNX export of a Brevitas ``QuantLinear`` against its forward pass.

    A ``QuantLinear`` layer is built with integer-quantized weights and
    randomly initialised float weights in [-1, 1), exported to
    ``export_onnx_path``, executed through ``oxe.execute_onnx`` and compared
    element-wise (``atol=1e-3``) with the output of the layer's own
    ``forward``.

    Args:
        bias: Passed through to ``QuantLinear(bias=...)``.
        out_features: Number of output features of the layer.
        in_features: Number of input features of the layer.
        w_bits: Weight bit width passed to ``QuantLinear``.
        i_dtype: Data type handed to ``gen_finn_dt_tensor`` to generate the
            input tensor.
    """
    i_shape = (1, in_features)
    w_shape = (out_features, in_features)
    b_linear = QuantLinear(
        out_features=out_features,
        in_features=in_features,
        bias=bias,
        bias_quant_type=QuantType.FP,
        weight_bit_width=w_bits,
        weight_quant_type=QuantType.INT,
        weight_scaling_per_output_channel=True,
    )
    weight_tensor_fp = np.random.uniform(low=-1.0, high=1.0, size=w_shape).astype(np.float32)
    b_linear.weight.data = torch.from_numpy(weight_tensor_fp)
    b_linear.eval()
    try:
        bo.export_finn_onnx(b_linear, i_shape, export_onnx_path)
        model = ModelWrapper(export_onnx_path)
        model = model.transform(InferShapes())
        inp_tensor = gen_finn_dt_tensor(i_dtype, i_shape)
        idict = {model.graph.input[0].name: inp_tensor}
        odict = oxe.execute_onnx(model, idict, True)
        produced = odict[model.graph.output[0].name]
        # The reference result comes from running the same input through the
        # Brevitas layer directly.
        inp_tensor = torch.from_numpy(inp_tensor).float()
        expected = b_linear.forward(inp_tensor).detach().numpy()
        assert np.isclose(produced, expected, atol=1e-3).all()
    finally:
        # Remove the exported file even when the comparison (or any earlier
        # step) fails, so a failing run does not leave artifacts behind. The
        # existence check avoids masking an export error with FileNotFoundError.
        if os.path.exists(export_onnx_path):
            os.remove(export_onnx_path)
def test_quant_linear(bias, bias_quant, out_features, in_features, w_bits, channel_scaling, i_bits):
    """Check a FINN-ONNX export of ``QuantLinear`` against the layer's own output.

    The layer is fed a quantized input produced by a ``QuantIdentity``,
    exported with ``bo.export_finn_onnx`` to ``'linear.onnx'``, executed via
    ``oxe.execute_onnx`` and compared element-wise (``atol=1e-3``) with the
    result of calling the layer directly.

    Args:
        bias: Passed through to ``QuantLinear(bias=...)``.
        bias_quant: Passed through to ``QuantLinear(bias_quant=...)``.
        out_features: Number of output features of the layer.
        in_features: Number of input features of the layer.
        w_bits: Weight bit width passed to ``QuantLinear``.
        channel_scaling: Value for ``weight_scaling_per_output_channel``.
        i_bits: Bit width of the ``QuantIdentity`` that quantizes the input.
    """
    # NOTE: the exported 'linear.onnx' is not removed by this test.

    # Required to generate quantized inputs; not part of the exported model
    # under test.
    input_quantizer = QuantIdentity(bit_width=i_bits, return_quant_tensor=True)
    quant_input = input_quantizer(torch.randn(1, in_features))

    layer = QuantLinear(
        out_features=out_features,
        in_features=in_features,
        bias=bias,
        bias_quant=bias_quant,
        weight_bit_width=w_bits,
        weight_scaling_per_output_channel=channel_scaling,
    )
    layer.eval()

    exported = bo.export_finn_onnx(layer, input_t=quant_input, export_path='linear.onnx')
    wrapped = ModelWrapper(exported).transform(InferShapes())

    # The quantized input tensor passed to FINN should be in integer form.
    int_input = quant_input.int(float_datatype=True).numpy()
    input_name = wrapped.graph.input[0].name
    exec_result = oxe.execute_onnx(wrapped, {input_name: int_input}, True)
    output_name = wrapped.graph.output[0].name
    produced = exec_result[output_name]

    expected = layer(quant_input).detach().numpy()
    matches = np.isclose(produced, expected, atol=1e-3)
    assert matches.all()