Example #1
 def forward(self, x: torch.Tensor) -> QuantTensor:
     if self.is_quant_enabled:
         out, scale, zero_point, bit_width = self.tensor_quant(x)
         return QuantTensor(out, scale, zero_point, bit_width,
                            self.is_signed, self.training)
     else:  # quantization disabled
         return QuantTensor(x, training=self.training)
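The two return paths above produce either a fully annotated QuantTensor or a bare wrapper with no metadata. A minimal sketch of both shapes, assuming the constructor order and import path used throughout these snippets (brevitas.quant_tensor.QuantTensor with value, scale, zero_point, bit_width, signed, training):

import torch
from brevitas.quant_tensor import QuantTensor  # import path assumed

x = torch.randn(2, 4)

# Fully annotated tensor, as returned when quantization is enabled.
qt = QuantTensor(
    x, scale=torch.tensor(0.1), zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.), signed=True, training=False)
print(qt.value.shape, qt.scale, qt.bit_width)

# Bare wrapper, as returned when quantization is disabled: metadata stays None.
qt_plain = QuantTensor(x, training=False)
print(qt_plain.scale is None, qt_plain.bit_width is None)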
Example #2
 def forward(self, x: torch.Tensor) -> QuantTensor:
     if self.is_quant_enabled:
         impl = self.export_handler if self.export_mode else self.tensor_quant
         out, pre_scale, pre_zero_point, scale, zero_point, bit_width = impl(x)
         return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
     else:  # quantization disabled
         return QuantTensor(x, training=self.training)
Example #3
 def forward(self,
             x: Tensor,
             input_scale: Optional[Tensor] = None,
             input_bit_width: Optional[Tensor] = None) -> QuantTensor:
     if self.is_quant_enabled:
         if self.requires_input_scale and input_scale is None:
             raise RuntimeError("Input scale required")
         if self.requires_input_bit_width and input_bit_width is None:
             raise RuntimeError("Input bit-width required")
         if self.requires_input_scale and self.requires_input_bit_width:
             input_scale = input_scale.view(-1)
             out, out_scale, out_bit_width, out_zp = self.tensor_quant(
                 x, input_scale, input_bit_width)
         elif self.requires_input_scale and not self.requires_input_bit_width:
             input_scale = input_scale.view(-1)
             out, out_scale, out_bit_width, out_zp = self.tensor_quant(
                 x, input_scale)
         elif not self.requires_input_scale and not self.requires_input_bit_width:
             out, out_scale, out_bit_width, out_zp = self.tensor_quant(x)
         else:
             raise RuntimeError("Internally defined bit-width required")
         return QuantTensor(out, out_scale, out_bit_width, out_zp,
                            self.is_signed, self.training)
     else:
         return QuantTensor(x, training=self.training)
Example #4
 def forward(self, x: QuantTensor):
     if self.is_quant_enabled:
         # clean up accumulated floating point errors
         cleaned_up_value = round_ste(x.value / x.scale.detach()) * x.scale.detach()
         x = x.set(value=cleaned_up_value)
         trunc_bit_width = self.lsb_trunc_bit_width_impl(x.bit_width)
         trunc_scale = 2.0**trunc_bit_width
         output_scale = trunc_scale * x.scale
         if self.training:
             x, output_scale, x_bit_width = self.tensor_quant(
                 x.value, output_scale, x.bit_width)
         else:  # avoid fp errors at inference time
             x_bit_width = x.bit_width
             x = round_ste(x.value / x.scale)
             x = x / trunc_scale
             x = self.tensor_quant.int_quant.float_to_int_impl(x)
             x = x * output_scale
         x = x / trunc_scale
         output_scale = output_scale / trunc_scale  # output_scale == input_scale
         output_bit_width = x_bit_width - trunc_bit_width
         return QuantTensor(x, output_scale, output_bit_width,
                            self.is_signed)
     else:
         return x
Example #5
 def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
     if self.is_act_enabled or self.is_quant_enabled:
         if isinstance(x, QuantTensor):
             x = x.value
         x = self.fused_activation_quant_proxy(x)
         return QuantTensor(*x, signed=self.is_signed)
     else:
         if isinstance(x, QuantTensor):  # passthrough
             return x
         else:
             return QuantTensor(x)
Example #6
 def pack_output(self, quant_output: QuantTensor):
     if not self.training and self.cache_inference_quant_out:
         self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only)
     if self.return_quant_tensor:
         return quant_output
     else:
         return quant_output.value
Example #7
 def forward(self, x: QuantTensor):
     if self.is_quant_enabled:
         out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
         out_value, out_scale, out_zp, out_bit_width = out_tuple
         return QuantTensor(out_value, out_scale, out_zp, out_bit_width,
                            self.is_signed, self.training)
     return x
Example #8
 def forward(self, x: QuantTensor):
     if self.is_quant_enabled:
         out_value, out_scale, out_bit_width = self.tensor_quant(
             x.value, x.scale, x.bit_width)
         return QuantTensor(out_value, out_scale, out_bit_width, x.signed)
     else:
         return x
Example #9
 def forward(self, inp):
     output_scale = None
     output_zp = None
     output_bit_width = None
     inp = self.unpack_input(inp)
     norm = inp.value.norm(p='fro', keepdim=True) + self.eps
     out = inp.value / norm
     out = nn.functional.linear(
         out, self.proj[:self.out_channels, :self.in_channels])
     out = -self.scale * out
     if inp.scale is not None:
         output_scale = inp.scale * self.scale / norm
     if inp.bit_width is not None:
         output_bit_width = self.max_output_bit_width(inp.bit_width)
     if (self.return_quant_tensor and inp.zero_point is not None
             and (inp.zero_point != 0.0).any()):
         raise RuntimeError(
             "Computing zero point of output accumulator not supported yet."
         )
     else:
         output_zp = inp.zero_point
     out = QuantTensor(value=out,
                       scale=output_scale,
                       zero_point=output_zp,
                       bit_width=output_bit_width,
                       signed=True,
                       training=self.training)
     return out
Example #10
 def pack_output(self, output, output_scale, output_bit_width):
     if self.return_quant_tensor:
         return QuantTensor(tensor=output,
                            scale=output_scale,
                            bit_width=output_bit_width)
     else:
         return output
Example #11
def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
    if size == "LFC" and wbits == 2 and abits == 2:
        pytest.skip(f"No LFC_{MAX_WBITS}W{MAX_ABITS}A present.")
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    nname = f"{size}_{wbits}W{abits}A"
    finn_onnx = nname + ".onnx"
    fc, _ = model_with_cfg(nname.lower(), pretrained=pretrained)
    fc.eval()
    # load a random int test vector
    input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32)
    scale = 1. / 255
    input_t = torch.from_numpy(input_a * scale)
    input_qt = QuantTensor(
        input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
    FINNManager.export(fc, export_path=finn_onnx, input_t=input_qt)
    model = ModelWrapper(finn_onnx)
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(DoubleToSingleFloat())
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    # run using FINN-based execution
    input_dict = {"0": input_a}
    output_dict = oxe.execute_onnx(model, input_dict)
    produced = output_dict[list(output_dict.keys())[0]]
    # do forward pass in PyTorch/Brevitas
    expected = fc.forward(input_t).detach().numpy()
    assert np.isclose(produced, expected, atol=ATOL).all()
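The test above packs an unsigned 8-bit input by hand: the float values are the raw integers times the scale, with an implicit zero point of 0. A small sketch of that relationship (the shape and integer range below are illustrative, not the test's constants):

import numpy as np
import torch
from brevitas.quant_tensor import QuantTensor  # import path assumed

scale = 1. / 255
ints = np.random.randint(0, 256, size=(1, 1, 28, 28)).astype(np.float32)  # illustrative range/shape
qt_in = QuantTensor(
    torch.from_numpy(ints * scale),
    scale=torch.tensor(scale),
    bit_width=torch.tensor(8.0),
    signed=False)
# Dividing the dequantized values by the scale recovers the original integers.
assert torch.allclose(qt_in.value / qt_in.scale, torch.from_numpy(ints))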
Example #12
 def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
     if self.is_act_enabled or self.is_quant_enabled:
         y = x
         if isinstance(y, QuantTensor):
             y = y.value
         y = self.fused_activation_quant_proxy(y)
         if isinstance(y, tuple):
             return QuantTensor(*y, signed=self.is_signed)
         elif self.passthrough_act:  # preserve scale/bit/sign even without output quant
             return QuantTensor(y, x.scale, x.bit_width, x.signed)
         else:
             return QuantTensor(y)
     else:
         if isinstance(x, QuantTensor):  # passthrough
             return x
         else:
             return QuantTensor(x)
Example #13
 def test_forward_bias_int(self):
     mod = QuantLinear(
         out_features=OUTPUT_FEATURES,
         in_features=INPUT_FEATURES,
         bias=True,
         bias_quant_type='INT')
     x = QuantTensor(torch.rand(size=(3, INPUT_FEATURES)), torch.tensor(1.0), torch.tensor(3))
     assert mod(x) is not None
Example #14
 def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
     if self.fused_activation_quant_proxy is not None:
         y = x
         if isinstance(y, QuantTensor):
             y = y.value
         y = self.fused_activation_quant_proxy(y)
         if isinstance(y, tuple):
             return QuantTensor(*y, signed=self.is_signed, training=self.training)
         elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
             return QuantTensor(
                 y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
         else:
             return QuantTensor(y, training=self.training)
     else:
         if isinstance(x, QuantTensor):  # passthrough
             return x
         else:
             return QuantTensor(x, training=self.training)
Example #15
 def forward(
         self,
         x: Tensor,
         input_scale: Tensor,
         input_bit_width: Optional[Tensor]) -> QuantTensor:
     if self.is_quant_enabled:
         if input_scale is None:
             raise RuntimeError("Input scale can't be None when quantizing bias")
         input_scale = input_scale.view(-1)
         if self.requires_input_bit_width:  # bit width is defined outside
             if input_bit_width is None:
                 raise RuntimeError("Input or predefined bit width required")
             out, out_scale, out_bit_width = self.tensor_quant(x, input_scale, input_bit_width)
         else:
             out, out_scale, out_bit_width = self.tensor_quant(x, input_scale)
         return QuantTensor(out, out_scale, out_bit_width, self.is_signed)
     else:
         return QuantTensor(x)
Example #16
 def forward(self, x: QuantTensor):
     if self.is_quant_enabled:
         impl = self.export_handler if self.export_mode else self.tensor_quant
         out_tuple = impl(x.value, x.scale, x.zero_point, x.bit_width)
         out_value, out_scale, out_zp, out_bit_width = out_tuple
         return QuantTensor(out_value, out_scale, out_zp, out_bit_width,
                            x.signed, self.training)
     else:
         return x
Example #17
 def forward(self,
             tensor_list: Union[List[Tensor], List[QuantTensor]],
             dim: int = 1) -> Union[Tensor, QuantTensor]:
     quant_tensor_list = [self.unpack_input(t) for t in tensor_list]
     # shortcut execution through the export impl during export
     if self.export_mode:
         return self.export_handler([qt.value for qt in quant_tensor_list])
     quant_tensor_list = [self.input_quant(qt) for qt in quant_tensor_list]
     # trigger an assert if scale factors and bit widths are None or different
     output = QuantTensor.cat(quant_tensor_list, dim=dim)
     quant_output = self.output_quant(output)
     return self.pack_output(quant_output)
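QuantTensor.cat, as used above, expects its inputs to carry matching scale and bit width (the comment in the snippet notes that mismatches trigger an assert). A sketch under that assumption, with values generated as integer codes times a shared scale so they are consistent quantized values:

import torch
from brevitas.quant_tensor import QuantTensor  # import path assumed

scale, zp, bw = torch.tensor(0.25), torch.tensor(0.), torch.tensor(4.)

def fake_quant(shape):
    # signed 4-bit codes in [-8, 7] times the shared scale
    return torch.randint(-8, 8, shape).float() * scale

a = QuantTensor(fake_quant((1, 3, 8, 8)), scale, zp, bw, signed=True, training=False)
b = QuantTensor(fake_quant((1, 5, 8, 8)), scale, zp, bw, signed=True, training=False)
c = QuantTensor.cat([a, b], dim=1)  # call signature taken from the snippet above
print(c.value.shape)  # torch.Size([1, 8, 8, 8])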
Example #18
def test_export():
    x = QuantTensor(torch.randn(IN_SIZE),
                    scale=torch.tensor(2.0**(-7)),
                    bit_width=torch.tensor(8),
                    signed=True)
    mod = QuantModel()
    # Export quantized model to ONNX
    export_dpuv1_onnx(mod,
                      input_shape=IN_SIZE,
                      input_t=x,
                      export_path='quant_model.onnx',
                      input_names=["input_%d" % i for i in range(5)],
                      output_names=["output"])
Example #19
    def forward_impl(
            self, inp: Union[Tensor,
                             QuantTensor]) -> Union[Tensor, QuantTensor]:
        output_scale = None
        output_bit_width = None

        inp = self.unpack_input(inp)

        # shortcut execution through the export impl during export
        if self.export_mode:
            return self.export_handler(inp.value)

        quant_input = self.input_quant(inp)
        quant_weight = self.quant_weight()

        if quant_input.bit_width is not None:
            output_bit_width = self.max_acc_bit_width(quant_input.bit_width,
                                                      quant_weight.bit_width)
        if quant_input.scale is not None:
            output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
            output_scale = quant_weight.scale.view(output_scale_shape)
            output_scale = output_scale * quant_input.scale.view(
                output_scale_shape)

        if self.bias is not None:
            quant_bias = self.bias_quant(self.bias, output_scale,
                                         output_bit_width)
            if not self.training and self.cache_inference_quant_bias:
                self._cached_bias = _CachedIO(quant_bias.detach(),
                                              metadata_only=False)

            output_tensor = self.inner_forward_impl(quant_input.value,
                                                    quant_weight.value,
                                                    quant_bias.value)

            if quant_bias.bit_width is not None and output_bit_width is not None:
                output_bit_width = torch.where(
                    quant_bias.bit_width > output_bit_width,
                    quant_bias.bit_width, output_bit_width)
                output_bit_width = output_bit_width + 1
        else:
            output_tensor = self.inner_forward_impl(quant_input.value,
                                                    quant_weight.value, None)

        quant_output = QuantTensor(output_tensor,
                                   output_scale,
                                   output_bit_width,
                                   signed=True)
        quant_output = self.output_quant(quant_output)
        return self.pack_output(quant_output)
Example #20
 def unpack_input(self, inp: Union[Tensor, QuantTensor]):
     if isinstance(inp, QuantTensor):
         if self.export_mode:
             raise RuntimeError("QuantTensor I/O can't be used during export.")
         if not self.training and self.cache_inference_quant_inp:
             cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
             self._cached_inp = cached_inp
         return inp
     else:
         inp = QuantTensor(inp)
         if not self.training and self.cache_inference_quant_inp:
             cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
             self._cached_inp = cached_inp
         return inp
Example #21
 def unpack_input(self, inp: Union[Tensor, QuantTensor]):
     self._set_global_is_quant_layer(True)
     # Hack to recognize a QuantTensor that has decayed to a tuple
     # when used as input to tracing (e.g. during ONNX export)
     if (torch._C._get_tracing_state() is not None
             and isinstance(inp, tuple)
             and len(inp) == len(QuantTensor._fields)
             and all([isinstance(t, Tensor) for t in inp])):
         inp = QuantTensor(*inp)
     if isinstance(inp, QuantTensor):
         # don't cache values during export pass
         if not self.training and not self._export_mode and self.cache_inference_quant_inp:
             cached_inp = _CachedIO(inp.detach(),
                                    self.cache_quant_io_metadata_only)
             self._cached_inp = cached_inp
         return inp
     else:
         inp = QuantTensor(inp, training=self.training)
         if not self.training and self.cache_inference_quant_inp:
             cached_inp = _CachedIO(inp.detach(),
                                    self.cache_quant_io_metadata_only)
             self._cached_inp = cached_inp
         return inp
Example #22
 def unpack_input(self, inp: Union[Tensor, QuantTensor]):
     if isinstance(inp, QuantTensor):
         # don't cache values during export pass
         if not self.training and not self._export_mode and self.cache_inference_quant_inp:
             cached_inp = _CachedIO(inp.detach(),
                                    self.cache_quant_io_metadata_only)
             self._cached_inp = cached_inp
         return inp
     else:
         inp = QuantTensor(inp)
         if not self.training and self.cache_inference_quant_inp:
             cached_inp = _CachedIO(inp.detach(),
                                    self.cache_quant_io_metadata_only)
             self._cached_inp = cached_inp
         return inp
Example #23
 def pack_output(self, output, output_scale, output_bit_width):
     if self._export_mode:
         # do not ever return QuantTensor while exporting
         # cached scale factors will be used in the next layer
         return output
     else:
         self.export_out_shape = output.shape
         # TODO control caching with own config variable
         self.export_out_scale = output_scale
         self.export_out_bit_width = output_bit_width
         if self.return_quant_tensor:
             return QuantTensor(tensor=output,
                                scale=output_scale,
                                bit_width=output_bit_width)
         else:
             return output
Example #24
def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width,
                                  input_bit_width, channels, idim):

    quant_avgpool = QuantAvgPool2d(kernel_size=kernel_size,
                                   stride=stride,
                                   bit_width=bit_width)
    quant_avgpool.eval()

    # determine input
    prefix = 'INT' if signed else 'UINT'
    dt_name = prefix + str(input_bit_width)
    dtype = DataType[dt_name]
    input_shape = (1, channels, idim, idim)
    input_array = gen_finn_dt_tensor(dtype, input_shape)
    scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1,
                                                         1)).astype(np.float32)
    input_tensor = torch.from_numpy(input_array * scale_array).float()
    scale_tensor = torch.from_numpy(scale_array).float()
    zp = torch.tensor(0.)
    input_quant_tensor = QuantTensor(input_tensor,
                                     scale_tensor,
                                     zp,
                                     input_bit_width,
                                     signed,
                                     training=False)

    # export
    FINNManager.export(quant_avgpool,
                       export_path=export_onnx_path,
                       input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # reference brevitas output
    ref_output_array = quant_avgpool(
        input_quant_tensor).tensor.detach().numpy()
    # finn output
    idict = {model.graph.input[0].name: input_array}
    odict = oxe.execute_onnx(model, idict, True)
    finn_output = odict[model.graph.output[0].name]
    # compare outputs
    assert np.isclose(ref_output_array, finn_output).all()
    # cleanup
    os.remove(export_onnx_path)
Example #25
 def __init__(self, quant_tensor: QuantTensor, metadata_only: bool):
     self.shape = quant_tensor.value.shape
     if metadata_only:
         self.quant_tensor = quant_tensor.set(value=None)
     else:
         self.quant_tensor = quant_tensor
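_CachedIO above either keeps the whole QuantTensor or, with metadata_only=True, drops the values via set(value=None) and keeps only the quantization metadata plus the recorded shape. A small sketch using the class as defined above (QuantTensor constructor as in the other snippets):

import torch
from brevitas.quant_tensor import QuantTensor  # import path assumed

qt = QuantTensor(
    torch.zeros(1, 16), scale=torch.tensor(0.25), zero_point=torch.tensor(0.),
    bit_width=torch.tensor(8.), signed=True, training=False)

full_cache = _CachedIO(qt, metadata_only=False)
meta_cache = _CachedIO(qt, metadata_only=True)

print(full_cache.quant_tensor.value.shape)    # values kept
print(meta_cache.quant_tensor.value is None)  # values dropped, metadata kept
print(meta_cache.quant_tensor.scale)          # tensor(0.2500)
print(meta_cache.shape)                       # original shape is still recorded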
Example #26
 def forward(self, x: torch.Tensor) -> QuantTensor:
     if self.is_quant_enabled:
         out, scale, bit_width = self.tensor_quant(x)
         return QuantTensor(out, scale, bit_width, signed=self.is_signed)
     else:  # quantization disabled
         return QuantTensor(x)
Example #27
def test_brevitas_avg_pool_export(kernel_size, stride, signed, bit_width,
                                  input_bit_width, channels, idim):
    ishape = (1, channels, idim, idim)
    ibw_tensor = torch.Tensor([input_bit_width])

    b_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        signed=signed,
        min_overall_bit_width=bit_width,
        max_overall_bit_width=bit_width,
        quant_type=QuantType.INT,
    )
    # call forward pass manually once to cache scale factor and bitwidth
    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
    scale = np.ones((1, channels, 1, 1))
    output_scale = torch.from_numpy(scale).float()
    input_quant_tensor = QuantTensor(input_tensor, output_scale, ibw_tensor,
                                     signed)
    FINNManager.export_onnx(b_avgpool,
                            ishape,
                            export_onnx_path,
                            input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)

    # determine input FINN datatype
    if signed is True:
        prefix = "INT"
    else:
        prefix = "UINT"
    dt_name = prefix + str(input_bit_width // 2)
    dtype = DataType[dt_name]
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # execution with input tensor using integers and scale = 1
    # calculate golden output
    inp = gen_finn_dt_tensor(dtype, ishape)
    input_tensor = torch.from_numpy(inp).float()
    input_quant_tensor = QuantTensor(input_tensor, output_scale, ibw_tensor,
                                     signed)
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()

    # finn execution
    idict = {model.graph.input[0].name: inp}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert (expected == produced).all()

    # execution with input tensor using float and scale != 1
    scale = np.random.uniform(low=0, high=1,
                              size=(1, channels, 1, 1)).astype(np.float32)
    inp_tensor = inp * scale
    input_tensor = torch.from_numpy(inp_tensor).float()
    input_scale = torch.from_numpy(scale).float()
    input_quant_tensor = QuantTensor(input_tensor, input_scale, ibw_tensor,
                                     signed)
    # export again to set the scale values correctly
    bo.export_finn_onnx(b_avgpool,
                        ishape,
                        export_onnx_path,
                        input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # finn execution
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]

    assert np.isclose(expected, produced).all()

    os.remove(export_onnx_path)
Example #28
 def bit_width(self):
     zhs = self._zero_hw_sentinel()
     empty_imp = QuantTensor(zhs, zhs, zhs, zhs)
     bit_width = self.__call__(empty_imp).bit_width
     return bit_width
Example #29
    def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
        output_scale = None
        output_bit_width = None
        output_zero_point = None
        output_signed = None

        inp = self.unpack_input(inp)

        # shortcut execution through the export impl during export
        if self.export_mode:
            out = self.export_handler(inp.value)
            self._set_global_is_quant_layer(False)
            return out

        quant_input = self.input_quant(inp)
        quant_weight = self.quant_weight()

        if quant_input.bit_width is not None:
            output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width)
        if quant_input.scale is not None:
            output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
            output_scale = quant_weight.scale.view(output_scale_shape)
            output_scale = output_scale * quant_input.scale.view(output_scale_shape)
        if quant_input.signed is not None:
            output_signed = inp.signed or quant_weight.signed

        if self.bias is not None:
            quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
            if not self.training and self.cache_inference_quant_bias:
                self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)

            output_tensor = self.inner_forward_impl(
                quant_input.value, quant_weight.value, quant_bias.value)

            if (output_scale is not None
                    and (quant_bias.scale is None
                         or (quant_bias.scale is not None
                             and quant_bias.scale.data_ptr() != output_scale.data_ptr()))):
                output_zero_point = - quant_bias.value.view(output_scale_shape) / output_scale

            if quant_bias.bit_width is not None and output_bit_width is not None:
                output_bit_width = torch.where(
                    quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
                output_bit_width = output_bit_width + 1
        else:
            output_tensor = self.inner_forward_impl(quant_input.value, quant_weight.value, None)

        if self.return_quant_tensor and not self.is_output_quant_enabled:
            if (quant_input.zero_point is not None
                    and ((quant_input.zero_point != 0.0).any()
                         or (quant_weight.zero_point != 0.0).any())):
                raise RuntimeError("Computing zero point of output accumulator not supported yet.")
            elif quant_input.zero_point is not None and output_zero_point is None:
                output_zero_point = quant_input.zero_point

        quant_output = QuantTensor(
            value=output_tensor,
            scale=output_scale,
            zero_point=output_zero_point,
            bit_width=output_bit_width,
            signed=output_signed,
            training=self.training)
        quant_output = self.output_quant(quant_output)
        return self.pack_output(quant_output)
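The bias branch above widens the accumulator: the output bit width becomes the larger of the bias bit width and the accumulator bit width, plus one for the possible carry. A worked illustration of just that torch.where step, with made-up numbers:

import torch

acc_bit_width = torch.tensor(21.)   # e.g. from max_acc_bit_width(input_bw, weight_bw)
bias_bit_width = torch.tensor(32.)  # e.g. a wide bias quantizer

out_bit_width = torch.where(
    bias_bit_width > acc_bit_width, bias_bit_width, acc_bit_width)
out_bit_width = out_bit_width + 1
print(out_bit_width)  # tensor(33.)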
Example #30
def test_brevitas_avg_pool_export(
    kernel_size,
    stride,
    signed,
    bit_width,
    input_bit_width,
    channels,
    idim,
    QONNX_export,
):
    export_onnx_path = base_export_onnx_path.replace(
        ".onnx", f"test_QONNX-{QONNX_export}.onnx"
    )
    quant_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        bit_width=bit_width,
        return_quant_tensor=False,
    )
    quant_avgpool.eval()

    # determine input
    prefix = "INT" if signed else "UINT"
    dt_name = prefix + str(input_bit_width)
    dtype = DataType[dt_name]
    input_shape = (1, channels, idim, idim)
    input_array = gen_finn_dt_tensor(dtype, input_shape)
    # Brevitas QuantAvgPool layers need QuantTensors to export correctly
    # which requires setting up a QuantTensor instance with the scale
    # factor, zero point, bitwidth and signedness
    scale_array = np.ones((1, channels, 1, 1)).astype(np.float32)
    scale_array *= 0.5
    input_tensor = torch.from_numpy(input_array * scale_array).float()
    scale_tensor = torch.from_numpy(scale_array).float()
    zp = torch.tensor(0.0)
    input_quant_tensor = QuantTensor(
        input_tensor, scale_tensor, zp, input_bit_width, signed, training=False
    )

    # export
    if QONNX_export:
        BrevitasONNXManager.export(
            quant_avgpool,
            export_path=export_onnx_path,
            input_t=input_quant_tensor,
        )
        model = ModelWrapper(export_onnx_path)

        # Statically set the additional inputs generated by the BrevitasONNXManager
        model.graph.input.remove(model.graph.input[3])
        model.graph.input.remove(model.graph.input[2])
        model.graph.input.remove(model.graph.input[1])
        model.set_initializer("1", scale_array)
        model.set_initializer("2", np.array(0.0).astype(np.float32))
        model.set_initializer("3", np.array(input_bit_width).astype(np.float32))
        model.save(export_onnx_path)

        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
        model = ModelWrapper(export_onnx_path)
        model = model.transform(ConvertQONNXtoFINN())
        model.save(export_onnx_path)
    else:
        FINNManager.export(
            quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
        )
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # reference brevitas output
    ref_output_array = quant_avgpool(input_quant_tensor).detach().numpy()
    # finn output
    if QONNX_export:
        # Manually apply the Quant tensor scaling for QONNX
        idict = {model.graph.input[0].name: input_array * scale_array}
    else:
        idict = {model.graph.input[0].name: input_array}
    odict = oxe.execute_onnx(model, idict, True)
    finn_output = odict[model.graph.output[0].name]
    # compare outputs
    assert np.isclose(ref_output_array, finn_output).all()
    # cleanup
    # assert False
    os.remove(export_onnx_path)