Ejemplo n.º 1
0
def test_generic_quant_conv_export():
    """Export a fully quantized QuantConv2d through BrevitasONNXManager.

    The convolution quantizes its input, output (Int8ActPerTensorFloat) and
    bias (Int16Bias). A single forward pass is run before switching to eval
    mode ("collect scale factors"). The model is then written to
    ``./generic_quant_conv.onnx``.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            conv_kwargs = dict(
                out_channels=OUT_CH,
                in_channels=IN_CH,
                bias=True,
                kernel_size=3,
                input_quant=Int8ActPerTensorFloat,
                output_quant=Int8ActPerTensorFloat,
                bias_quant=Int16Bias,
                return_quant_tensor=False,
            )
            self.conv = QuantConv2d(**conv_kwargs)
            # initialise the weights with small uniform values
            self.conv.weight.data.uniform_(-0.01, 0.01)

        def forward(self, x):
            return self.conv(x)

    # (batch, IN_CH, IN_CH, IN_CH) random input
    sample = torch.randn((2, IN_CH, IN_CH, IN_CH))
    net = Model()
    net(sample)  # training-mode pass so the quantizers collect their scales
    net.eval()
    BrevitasONNXManager.export(
        net,
        input_t=sample,
        export_path='./generic_quant_conv.onnx',
    )
Ejemplo n.º 2
0
def test_generic_quant_tensor_export():
    """Export a QuantLinear that receives a QuantTensor from a QuantIdentity.

    A QuantIdentity with ``return_quant_tensor=True`` feeds a QuantLinear
    with Int8 output and Int16 bias quantization. One forward pass is run
    first to collect scale factors, and the result is written to
    ``./generic_quant_tensor.onnx``.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant_inp = QuantIdentity(return_quant_tensor=True)
            linear_kwargs = dict(
                out_features=OUT_CH,
                in_features=IN_CH,
                bias=True,
                output_quant=Int8ActPerTensorFloat,
                bias_quant=Int16Bias,
                return_quant_tensor=False,
            )
            self.linear = QuantLinear(**linear_kwargs)
            self.linear.weight.data.uniform_(-0.01, 0.01)

        def forward(self, x):
            quantized = self.quant_inp(x)
            return self.linear(quantized)

    sample = torch.randn((2, IN_CH))
    net = Model()
    net(sample)  # training-mode pass so the quantizers collect their scales
    net.eval()
    BrevitasONNXManager.export(
        net,
        input_t=sample,
        export_path='./generic_quant_tensor.onnx',
    )
Ejemplo n.º 3
0
 def test_export(self, topology, wbits, abits, QONNX_export):
     """Export a trained end-to-end network and record dashboard metadata.

     Unsupported (wbits > abits) and most LFC configurations are skipped.
     The network is exported either through BrevitasONNXManager followed by
     QONNX cleanup and QONNX->FINN conversion, or through the legacy FINN
     exporter. The network name, a timestamp and the current FINN git commit
     are then written via ``update_dashboard_data``.
     """
     if wbits > abits:
         pytest.skip("No wbits > abits end2end network configs for now")
     if topology == "lfc" and not (wbits == 1 and abits == 1):
         pytest.skip("Skipping certain lfc configs")
     model, ishape = get_trained_network_and_ishape(topology, wbits, abits)
     chkpt_name = get_checkpoint_name(
         topology, wbits, abits, QONNX_export, "export")
     if not QONNX_export:
         bo.export_finn_onnx(model, ishape, chkpt_name)
     else:
         BrevitasONNXManager.export(model, ishape, chkpt_name)
         qonnx_cleanup(chkpt_name, out_file=chkpt_name)
         converted = ModelWrapper(chkpt_name).transform(ConvertQONNXtoFINN())
         converted.save(chkpt_name)

     def record(key, value):
         # all dashboard entries share the same (topology, wbits, abits) key
         update_dashboard_data(topology, wbits, abits, key, value)

     record("network", "%s_w%da%d" % (topology, wbits, abits))
     record("datetime", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
     commit_bytes = subprocess.check_output(
         ["git", "rev-parse", "HEAD"], cwd="/workspace/finn")
     record("finn-commit", commit_bytes.decode("utf-8").strip())
     assert os.path.isfile(chkpt_name)
Ejemplo n.º 4
0
def test_brevitas_cnv_export_exec(wbits, abits, QONNX_export):
    """Export the trained CNV network and compare FINN execution to Brevitas.

    The exported graph is run on a CIFAR-10 test image of class 3 with
    FINN's ONNX executor. The result must match the Brevitas forward pass
    within ``atol=1e-3`` and predict class 3.
    """
    if wbits > abits:
        pytest.skip("No wbits > abits cases at the moment")
    cnv = get_test_model_trained("CNV", wbits, abits)
    ishape = (1, 3, 32, 32)
    if not QONNX_export:
        bo.export_finn_onnx(cnv, ishape, export_onnx_path)
    else:
        BrevitasONNXManager.export(cnv, ishape, export_onnx_path)
        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
        converted = ModelWrapper(export_onnx_path).transform(ConvertQONNXtoFINN())
        converted.save(export_onnx_path)

    model = ModelWrapper(export_onnx_path)
    for transformation_cls in (GiveUniqueNodeNames, InferShapes,
                               FoldConstants, RemoveStaticGraphInputs):
        model = model.transform(transformation_cls())
    assert len(model.graph.input) == 1
    assert len(model.graph.output) == 1

    data_file = pk.resource_filename(
        "finn.qnn-data", "cifar10/cifar10-test-data-class3.npz")
    image = np.load(data_file)["arr_0"].astype(np.float32) / 255
    assert image.shape == (1, 3, 32, 32)

    # FINN-based execution
    finn_results = oxe.execute_onnx(
        model, {model.graph.input[0].name: image}, True)
    produced = finn_results[model.graph.output[0].name]
    # PyTorch/Brevitas reference
    expected = cnv.forward(torch.from_numpy(image).float()).detach().numpy()
    assert np.isclose(produced, expected, atol=1e-3).all()
    assert np.argmax(produced) == 3
    os.remove(export_onnx_path)
Ejemplo n.º 5
0
def test_brevitas_QConv2d(dw, bias, in_channels, QONNX_export):
    """Compare a 4-bit QuantConv2d in Brevitas against its exported ONNX graph.

    Parameters
    ----------
    dw : bool
        ``True`` builds a depthwise 3x3 convolution (groups == in_channels,
        padding 1); anything else builds a pointwise 1x1 convolution with
        64 output channels.
    bias : bool
        Whether the convolution has a floating-point (QuantType.FP) bias.
    in_channels : int
        Number of input channels. The input shape, weight shape and group
        count are all derived from it.
    QONNX_export : bool
        Export via BrevitasONNXManager + QONNX cleanup + QONNX->FINN
        conversion instead of the legacy FINN exporter.
    """
    if dw is True:
        # depthwise: one 3x3 filter per input channel
        groups = in_channels
        out_channels = in_channels
        kernel_size = 3
        padding = 1
    else:
        # pointwise 1x1 projection to 64 channels
        groups = 1
        out_channels = 64
        kernel_size = 1
        padding = 0
    stride = 1
    # Derive the shapes from the layer configuration rather than assuming
    # in_channels == 32. Conv2d weights are laid out as
    # (out_channels, in_channels // groups, kH, kW).
    ishape = (1, in_channels, 111, 111)
    w_shape = (out_channels, in_channels // groups, kernel_size, kernel_size)

    b_conv = QuantConv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        groups=groups,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        bias=bias,
        bias_quant_type=QuantType.FP,
        weight_bit_width=4,
        weight_quant_type=QuantType.INT,
        weight_scaling_impl_type=ScalingImplType.STATS,
        weight_scaling_stats_op=StatsOp.MAX,
        weight_scaling_per_output_channel=True,
        weight_restrict_scaling_type=RestrictValueType.LOG_FP,
        weight_narrow_range=True,
        weight_scaling_min_val=2e-16,
    )
    # overwrite the weights with random INT4-valued entries
    weight_tensor = gen_finn_dt_tensor(DataType["INT4"], w_shape)
    b_conv.weight = torch.nn.Parameter(torch.from_numpy(weight_tensor).float())
    b_conv.eval()
    if QONNX_export:
        m_path = export_onnx_path
        BrevitasONNXManager.export(b_conv, ishape, m_path)
        qonnx_cleanup(m_path, out_file=m_path)
        model = ModelWrapper(m_path)
        model = model.transform(ConvertQONNXtoFINN())
        model.save(m_path)
    else:
        bo.export_finn_onnx(b_conv, ishape, export_onnx_path)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    inp_tensor = np.random.uniform(low=-1.0, high=1.0, size=ishape).astype(np.float32)
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    inp_tensor = torch.from_numpy(inp_tensor).float()
    expected = b_conv.forward(inp_tensor).detach().numpy()

    assert np.isclose(produced, expected, atol=1e-3).all()
    os.remove(export_onnx_path)
Ejemplo n.º 6
0
def test_generic_quant_avgpool_export_quant_input():
    """Export a standalone QuantAvgPool2d whose input is a QuantTensor.

    A separate QuantIdentity produces the QuantTensor. It is called once in
    training mode to collect scale factors, and both modules are then put
    in eval mode before export to
    ``generic_quant_avgpool_quant_input.onnx``.
    """
    sample = torch.randn((2, OUT_CH, IN_CH, IN_CH))
    input_quantizer = QuantIdentity(return_quant_tensor=True)
    pool = QuantAvgPool2d(kernel_size=2, return_quant_tensor=False)
    input_quantizer(sample)  # collect scale factors
    for module in (input_quantizer, pool):
        module.eval()
    quant_sample = input_quantizer(sample)
    BrevitasONNXManager.export(
        pool,
        input_t=quant_sample,
        export_path='generic_quant_avgpool_quant_input.onnx',
    )
Ejemplo n.º 7
0
def test_brevitas_act_export_relu(abits, max_val, scaling_impl_type,
                                  QONNX_export):
    """Check that an exported QuantReLU matches Brevitas numerically.

    For ``ScalingImplType.PARAMETER`` the learned scale is forced to 0.49
    through the state dict. The exported graph is run on uniform random
    inputs in ``[-1.0, max_val)`` and compared to the Brevitas output with
    ``atol=1e-3``. On mismatch, diagnostic values are printed before the
    assertion fails.
    """
    lower_bound = -1.0
    ishape = (1, 15)

    b_act = QuantReLU(
        bit_width=abits,
        max_val=max_val,
        scaling_impl_type=scaling_impl_type,
        restrict_scaling_type=RestrictValueType.LOG_FP,
        quant_type=QuantType.INT,
    )
    if scaling_impl_type == ScalingImplType.PARAMETER:
        learned_value_key = (
            "act_quant_proxy.fused_activation_quant_proxy.tensor_quant."
            "scaling_impl.learned_value"
        )
        b_act.load_state_dict(
            {learned_value_key: torch.tensor(0.49).type(torch.FloatTensor)})

    if not QONNX_export:
        bo.export_finn_onnx(b_act, ishape, export_onnx_path)
    else:
        BrevitasONNXManager.export(b_act, ishape, export_onnx_path)
        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
        converted = ModelWrapper(export_onnx_path).transform(ConvertQONNXtoFINN())
        converted.save(export_onnx_path)

    model = ModelWrapper(export_onnx_path).transform(InferShapes())
    inp_array = np.random.uniform(low=lower_bound, high=max_val,
                                  size=ishape).astype(np.float32)
    finn_results = oxe.execute_onnx(
        model, {model.graph.input[0].name: inp_array}, True)
    produced = finn_results[model.graph.output[0].name]
    inp_tensor = torch.from_numpy(inp_array).float()
    b_act.eval()
    expected = b_act.forward(inp_tensor).detach().numpy()

    matches = np.isclose(produced, expected, atol=1e-3).all()
    if not matches:
        def fmt(values):
            return ", ".join("{:8.4f}".format(v) for v in values)

        print(abits, max_val, scaling_impl_type)
        print("scale: ",
              b_act.quant_act_scale().type(torch.FloatTensor).detach())
        if abits < 5:
            print("thres:", fmt(b_act.export_thres[0]))
        print("input:", fmt(inp_tensor[0]))
        print("prod :", fmt(produced[0]))
        print("expec:", fmt(expected[0]))

    assert matches
    os.remove(export_onnx_path)
Ejemplo n.º 8
0
def test_brevitas_act_export_qhardtanh_nonscaled(abits, narrow_range, max_val,
                                                 QONNX_export):
    """Check an exported constant-scaled QuantHardTanh against Brevitas.

    The activation quant type is derived from ``abits``: ``None`` gives FP,
    1 gives BINARY, and anything else gives INT. Inputs are drawn uniformly
    from ``[-1.0, max_val)``.
    """
    def get_quant_type(bit_width):
        if bit_width is None:
            return QuantType.FP
        return QuantType.BINARY if bit_width == 1 else QuantType.INT

    lower_bound = -1.0
    ishape = (1, 10)
    b_act = QuantHardTanh(
        bit_width=abits,
        quant_type=get_quant_type(abits),
        max_val=max_val,
        min_val=lower_bound,
        restrict_scaling_type=RestrictValueType.LOG_FP,
        scaling_impl_type=ScalingImplType.CONST,
        narrow_range=narrow_range,
    )
    if not QONNX_export:
        bo.export_finn_onnx(b_act, ishape, export_onnx_path)
    else:
        BrevitasONNXManager.export(b_act, ishape, export_onnx_path)
        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
        converted = ModelWrapper(export_onnx_path).transform(ConvertQONNXtoFINN())
        converted.save(export_onnx_path)

    model = ModelWrapper(export_onnx_path).transform(InferShapes())
    inp_array = np.random.uniform(low=lower_bound, high=max_val,
                                  size=ishape).astype(np.float32)
    finn_results = oxe.execute_onnx(
        model, {model.graph.input[0].name: inp_array}, True)
    produced = finn_results[model.graph.output[0].name]
    expected = b_act.forward(torch.from_numpy(inp_array).float()).detach().numpy()
    assert np.isclose(produced, expected, atol=1e-3).all()
    os.remove(export_onnx_path)
Ejemplo n.º 9
0
def test_generic_quant_avgpool_export():
    """Export a QuantIdentity -> QuantAvgPool2d pipeline via BrevitasONNXManager.

    One training-mode forward pass collects scale factors before the model
    is switched to eval mode and written to ``./generic_quant_avgpool.onnx``.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.inp_quant = QuantIdentity(return_quant_tensor=True)
            self.pool = QuantAvgPool2d(kernel_size=2)

        def forward(self, x):
            quantized = self.inp_quant(x)
            return self.pool(quantized)

    sample = torch.randn((2, OUT_CH, IN_CH, IN_CH))
    net = Model()
    net(sample)  # training-mode pass so the quantizers collect their scales
    net.eval()
    BrevitasONNXManager.export(
        net,
        input_t=sample,
        export_path='./generic_quant_avgpool.onnx',
    )
Ejemplo n.º 10
0
def test_brevitas_qlinear(
    bias, out_features, in_features, w_bits, i_dtype, QONNX_export
):
    """Compare a per-channel quantized QuantLinear with its exported graph.

    Weights are drawn uniformly from ``[-1, 1)``. The input is a random
    tensor of FINN datatype ``i_dtype``, and the outputs must agree within
    ``atol=1e-3``.
    """
    i_shape = (1, in_features)
    w_shape = (out_features, in_features)
    b_linear = QuantLinear(
        out_features=out_features,
        in_features=in_features,
        bias=bias,
        bias_quant_type=QuantType.FP,
        weight_bit_width=w_bits,
        weight_quant_type=QuantType.INT,
        weight_scaling_per_output_channel=True,
    )
    fp_weights = np.random.uniform(
        low=-1.0, high=1.0, size=w_shape).astype(np.float32)
    b_linear.weight.data = torch.from_numpy(fp_weights)
    b_linear.eval()

    if not QONNX_export:
        bo.export_finn_onnx(b_linear, i_shape, export_onnx_path)
    else:
        BrevitasONNXManager.export(b_linear, i_shape, export_onnx_path)
        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
        converted = ModelWrapper(export_onnx_path).transform(ConvertQONNXtoFINN())
        converted.save(export_onnx_path)

    model = ModelWrapper(export_onnx_path).transform(InferShapes())
    inp_array = gen_finn_dt_tensor(i_dtype, i_shape)
    finn_results = oxe.execute_onnx(
        model, {model.graph.input[0].name: inp_array}, True)
    produced = finn_results[model.graph.output[0].name]
    expected = b_linear.forward(
        torch.from_numpy(inp_array).float()).detach().numpy()

    assert np.isclose(produced, expected, atol=1e-3).all()
    os.remove(export_onnx_path)
Ejemplo n.º 11
0
def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, QONNX_export):
    """Export a trained fully-connected MNIST network and compare executions.

    The exported model is written below ``export_onnx_path`` under a name
    that encodes the configuration. After shape inference, constant folding
    and static-input removal, it is run on the bundled MNIST test vector and
    compared with the Brevitas forward pass (``atol=1e-3``).
    """
    if size == "LFC" and wbits == 2 and abits == 2:
        pytest.skip("No LFC-w2a2 present at the moment")
    if wbits > abits:
        pytest.skip("No wbits > abits cases at the moment")
    nname = "%s_%dW%dA_QONNX-%d" % (size, wbits, abits, QONNX_export)
    finn_onnx = export_onnx_path + "/%s.onnx" % nname
    fc = get_test_model_trained(size, wbits, abits)
    ishape = (1, 1, 28, 28)
    if not QONNX_export:
        bo.export_finn_onnx(fc, ishape, finn_onnx)
    else:
        BrevitasONNXManager.export(fc, ishape, finn_onnx)
        qonnx_cleanup(finn_onnx, out_file=finn_onnx)
        converted = ModelWrapper(finn_onnx).transform(ConvertQONNXtoFINN())
        converted.save(finn_onnx)

    model = ModelWrapper(finn_onnx)
    for transformation_cls in (InferShapes, FoldConstants,
                               RemoveStaticGraphInputs):
        model = model.transform(transformation_cls())
    assert len(model.graph.input) == 1
    assert len(model.graph.output) == 1

    # one of the bundled MNIST test vectors
    raw_i = get_data("finn.data", "onnx/mnist-conv/test_data_set_0/input_0.pb")
    tensor_proto = onnx.load_tensor_from_string(raw_i)
    # FINN-based execution; only graph outputs are returned, and there is one
    finn_results = oxe.execute_onnx(
        model, {model.graph.input[0].name: nph.to_array(tensor_proto)})
    produced = next(iter(finn_results.values()))
    # PyTorch/Brevitas reference
    torch_input = torch.from_numpy(nph.to_array(tensor_proto)).float()
    assert torch_input.shape == (1, 1, 28, 28)
    expected = fc.forward(torch_input).detach().numpy()
    assert np.isclose(produced, expected, atol=1e-3).all()
Ejemplo n.º 12
0
def test_brevitas_avg_pool_export(
    kernel_size,
    stride,
    signed,
    bit_width,
    input_bit_width,
    channels,
    idim,
    QONNX_export,
):
    """Compare an exported QuantAvgPool2d against the Brevitas reference.

    A QuantAvgPool2d is exported with a QuantTensor input. The input holds
    values of FINN datatype INT<input_bit_width> (or UINT when ``signed`` is
    false), scaled by a per-channel factor of 0.5 with zero point 0. Export
    goes through either BrevitasONNXManager (followed by QONNX cleanup and
    QONNX->FINN conversion) or FINNManager. The resulting graph is executed
    with FINN's ONNX executor and compared with the Brevitas output via
    ``np.isclose`` using default tolerances.
    """
    # per-QONNX_export file name, derived from the module-level base path
    export_onnx_path = base_export_onnx_path.replace(
        ".onnx", f"test_QONNX-{QONNX_export}.onnx"
    )
    quant_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        bit_width=bit_width,
        return_quant_tensor=False,
    )
    quant_avgpool.eval()

    # determine input
    prefix = "INT" if signed else "UINT"
    dt_name = prefix + str(input_bit_width)
    dtype = DataType[dt_name]
    input_shape = (1, channels, idim, idim)
    input_array = gen_finn_dt_tensor(dtype, input_shape)
    # Brevitas QuantAvgPool layers need QuantTensors to export correctly
    # which requires setting up a QuantTensor instance with the scale
    # factor, zero point, bitwidth and signedness
    scale_array = np.ones((1, channels, 1, 1)).astype(np.float32)
    scale_array *= 0.5
    input_tensor = torch.from_numpy(input_array * scale_array).float()
    scale_tensor = torch.from_numpy(scale_array).float()
    zp = torch.tensor(0.0)
    input_quant_tensor = QuantTensor(
        input_tensor, scale_tensor, zp, input_bit_width, signed, training=False
    )

    # export
    if QONNX_export:
        BrevitasONNXManager.export(
            quant_avgpool,
            export_path=export_onnx_path,
            input_t=input_quant_tensor,
        )
        model = ModelWrapper(export_onnx_path)

        # Statically set the additional inputs generated by the BrevitasONNXManager.
        # Inputs are removed from the highest index down so that each removal
        # leaves the indices of the still-to-be-removed inputs unchanged.
        model.graph.input.remove(model.graph.input[3])
        model.graph.input.remove(model.graph.input[2])
        model.graph.input.remove(model.graph.input[1])
        # Replace them with constants: the per-channel scale, a zero point of
        # 0.0 and the input bit width. NOTE(review): assumes the removed
        # inputs are named "1", "2" and "3" in this order — confirm if the
        # exporter's naming changes.
        model.set_initializer("1", scale_array)
        model.set_initializer("2", np.array(0.0).astype(np.float32))
        model.set_initializer("3", np.array(input_bit_width).astype(np.float32))
        model.save(export_onnx_path)

        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
        model = ModelWrapper(export_onnx_path)
        model = model.transform(ConvertQONNXtoFINN())
        model.save(export_onnx_path)
    else:
        FINNManager.export(
            quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor
        )
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())

    # reference brevitas output
    ref_output_array = quant_avgpool(input_quant_tensor).detach().numpy()
    # finn output
    if QONNX_export:
        # Manually apply the Quant tensor scaling for QONNX
        idict = {model.graph.input[0].name: input_array * scale_array}
    else:
        # the FINN export is fed the unscaled integer values
        idict = {model.graph.input[0].name: input_array}
    odict = oxe.execute_onnx(model, idict, True)
    finn_output = odict[model.graph.output[0].name]
    # compare outputs
    assert np.isclose(ref_output_array, finn_output).all()
    # cleanup
    os.remove(export_onnx_path)
Ejemplo n.º 13
0
def test_end2end_cybsec_mlp_export(QONNX_export):
    """Export the trained cybersecurity MLP and sanity-check the ONNX graph.

    The 593-input Brevitas MLP is rebuilt and loaded from the bundled
    ``state_dict.pth`` (non-strict). Its first weight matrix is zero-padded
    to 600 inputs. The ``CybSecMLPForExport`` wrapper is then exported,
    either through BrevitasONNXManager followed by QONNX cleanup and
    QONNX->FINN conversion, or through the legacy FINN exporter with a
    bipolar QuantTensor input. The test checks the input shape, the leading
    and final node op types, the BIPOLAR input datatype and the INT2
    datatype of the first MatMul weights.

    Args:
        QONNX_export: use the QONNX export path instead of the FINN one.
    """
    assets_dir = pk.resource_filename("finn.qnn-data", "cybsec-mlp/")
    # load up trained net in Brevitas
    input_size = 593
    hidden1 = 64
    hidden2 = 64
    hidden3 = 64
    weight_bit_width = 2
    act_bit_width = 2
    num_classes = 1
    model = nn.Sequential(
        QuantLinear(input_size,
                    hidden1,
                    bias=True,
                    weight_bit_width=weight_bit_width),
        nn.BatchNorm1d(hidden1),
        nn.Dropout(0.5),
        QuantReLU(bit_width=act_bit_width),
        QuantLinear(hidden1,
                    hidden2,
                    bias=True,
                    weight_bit_width=weight_bit_width),
        nn.BatchNorm1d(hidden2),
        nn.Dropout(0.5),
        QuantReLU(bit_width=act_bit_width),
        QuantLinear(hidden2,
                    hidden3,
                    bias=True,
                    weight_bit_width=weight_bit_width),
        nn.BatchNorm1d(hidden3),
        nn.Dropout(0.5),
        QuantReLU(bit_width=act_bit_width),
        QuantLinear(hidden3,
                    num_classes,
                    bias=True,
                    weight_bit_width=weight_bit_width),
    )
    # assets_dir already ends with "/", so join instead of concatenating
    trained_state_dict = torch.load(
        os.path.join(assets_dir, "state_dict.pth"))["models_state_dict"][0]
    model.load_state_dict(trained_state_dict, strict=False)
    W_orig = model[0].weight.data.detach().numpy()
    # pad the second (593-sized) dimensions with 7 zeroes at the end
    W_new = np.pad(W_orig, [(0, 0), (0, 7)])
    model[0].weight.data = torch.from_numpy(W_new)
    model_for_export = CybSecMLPForExport(model)
    export_onnx_path = get_checkpoint_name("export", QONNX_export)
    input_shape = (1, 600)
    # Create a QuantTensor instance to mark the input as bipolar during export.
    # np.random.randint's upper bound is exclusive, so (0, 2) draws from
    # {0, 1}, which the next line maps to {-1, +1}.
    input_a = np.random.randint(0, 2, size=input_shape).astype(np.float32)
    input_a = 2 * input_a - 1
    scale = 1.0
    input_t = torch.from_numpy(input_a * scale)
    input_qt = QuantTensor(input_t,
                           scale=torch.tensor(scale),
                           bit_width=torch.tensor(1.0),
                           signed=True)

    if QONNX_export:
        # With the BrevitasONNXManager we need to manually set
        # the FINN DataType at the input (input_qt is not used on this path)
        BrevitasONNXManager.export(model_for_export,
                                   input_shape,
                                   export_path=export_onnx_path)
        model = ModelWrapper(export_onnx_path)
        model.set_tensor_datatype(model.graph.input[0].name,
                                  DataType["BIPOLAR"])
        model.save(export_onnx_path)
        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
        model = ModelWrapper(export_onnx_path)
        model = model.transform(ConvertQONNXtoFINN())
        model.save(export_onnx_path)
    else:
        bo.export_finn_onnx(model_for_export,
                            export_path=export_onnx_path,
                            input_t=input_qt)
    assert os.path.isfile(export_onnx_path)
    # reload the exported model and check its input shape
    finn_model = ModelWrapper(export_onnx_path)
    finnonnx_in_tensor_name = finn_model.graph.input[0].name
    assert tuple(finn_model.get_tensor_shape(finnonnx_in_tensor_name)) == (1,
                                                                           600)
    # verify a few exported ops
    if QONNX_export:
        # The first "Mul" node doesn't exist in the QONNX export,
        # because the QuantTensor scale is not exported.
        # However, this node would have been unity scale anyways and
        # the models are still equivalent.
        assert finn_model.graph.node[0].op_type == "Add"
        assert finn_model.graph.node[1].op_type == "Div"
        assert finn_model.graph.node[2].op_type == "MatMul"
        assert finn_model.graph.node[-1].op_type == "MultiThreshold"
    else:
        assert finn_model.graph.node[0].op_type == "Mul"
        assert finn_model.get_initializer(
            finn_model.graph.node[0].input[1]) == 1.0
        assert finn_model.graph.node[1].op_type == "Add"
        assert finn_model.graph.node[2].op_type == "Div"
        assert finn_model.graph.node[3].op_type == "MatMul"
        assert finn_model.graph.node[-1].op_type == "MultiThreshold"
    # verify datatypes on some tensors
    assert (finn_model.get_tensor_datatype(finnonnx_in_tensor_name) ==
            DataType["BIPOLAR"])
    first_matmul_w_name = finn_model.get_nodes_by_op_type("MatMul")[0].input[1]
    assert finn_model.get_tensor_datatype(
        first_matmul_w_name) == DataType["INT2"]
Ejemplo n.º 14
0
def test_brevitas_debug(QONNX_export, QONNX_FINN_conversion):
    """Compare tensors captured at Brevitas debug hooks with FINN execution.

    Debug hooks are enabled on the trained TFC-w2a2 network *before* export,
    so the exported graph contains DebugMarker nodes. Their domain is
    rewritten to ``finn.custom_op.general`` so FINN can execute them. The
    exported model is run on an MNIST test vector with the full execution
    context returned. The final output must match Brevitas. The number of
    tensor names shared between the debug hook and the FINN execution
    context is checked per export mode, and every shared tensor must match
    within ``atol=1e-5``.

    Args:
        QONNX_export: export via BrevitasONNXManager instead of the FINN exporter.
        QONNX_FINN_conversion: additionally convert the QONNX model to FINN
            ONNX (only valid together with ``QONNX_export``).
    """
    if (not QONNX_export) and QONNX_FINN_conversion:
        pytest.skip(
            "This test configuration is not valid and is thus skipped.")
    finn_onnx = "test_brevitas_debug.onnx"
    fc = get_test_model_trained("TFC", 2, 2)
    ishape = (1, 1, 28, 28)
    if QONNX_export:
        # the hook must be installed before export so markers end up in the graph
        dbg_hook = bo.enable_debug(fc, proxy_level=True)
        BrevitasONNXManager.export(fc, ishape, finn_onnx)
        # DebugMarkers have the brevitas.onnx domain, so that needs adjusting
        model = ModelWrapper(finn_onnx)
        dbg_nodes = model.get_nodes_by_op_type("DebugMarker")
        for dbg_node in dbg_nodes:
            dbg_node.domain = "finn.custom_op.general"
        model.save(finn_onnx)
        qonnx_cleanup(finn_onnx, out_file=finn_onnx)
        if QONNX_FINN_conversion:
            model = ModelWrapper(finn_onnx)
            model = model.transform(ConvertQONNXtoFINN())
            model.save(finn_onnx)
    else:
        dbg_hook = bo.enable_debug(fc)
        bo.export_finn_onnx(fc, ishape, finn_onnx)
        model = ModelWrapper(finn_onnx)
        # DebugMarkers have the brevitas.onnx domain, so that needs adjusting
        # ToDo: We should probably have transformation pass, which does this
        #  domain conversion for us?
        dbg_nodes = model.get_nodes_by_op_type("DebugMarker")
        for dbg_node in dbg_nodes:
            dbg_node.domain = "finn.custom_op.general"
        model = model.transform(InferShapes())
        model = model.transform(FoldConstants())
        model = model.transform(RemoveStaticGraphInputs())
        model.save(finn_onnx)
    model = ModelWrapper(finn_onnx)
    assert len(model.graph.input) == 1
    assert len(model.graph.output) == 1
    # load one of the test vectors
    raw_i = get_data("finn.data", "onnx/mnist-conv/test_data_set_0/input_0.pb")
    input_tensor = onnx.load_tensor_from_string(raw_i)
    # run using FINN-based execution; the full execution context is requested
    # so intermediate tensors can be matched against the debug hook below
    input_dict = {model.graph.input[0].name: nph.to_array(input_tensor)}
    output_dict = oxe.execute_onnx(model,
                                   input_dict,
                                   return_full_exec_context=True)
    produced = output_dict[model.graph.output[0].name]
    # run using PyTorch/Brevitas
    input_tensor = torch.from_numpy(nph.to_array(input_tensor)).float()
    assert input_tensor.shape == (1, 1, 28, 28)
    # do forward pass in PyTorch/Brevitas (this also populates dbg_hook.values)
    expected = fc.forward(input_tensor).detach().numpy()
    assert np.isclose(produced, expected, atol=1e-3).all()
    # check all tensors at debug markers
    names_brevitas = set(dbg_hook.values.keys())
    names_finn = set(output_dict.keys())
    names_common = names_brevitas.intersection(names_finn)
    # The different exports return debug markers in different numbers and places
    print(len(names_common))
    if QONNX_export and not QONNX_FINN_conversion:
        assert len(names_common) == 12
    elif QONNX_export and QONNX_FINN_conversion:
        assert len(names_common) == 8
    else:
        assert len(names_common) == 16
    for dbg_name in names_common:
        if QONNX_export:
            # proxy-level hook values are wrapped; the raw tensor sits in
            # `.value` (presumably a QuantTensor — confirm against brevitas)
            tensor_pytorch = dbg_hook.values[dbg_name].value.detach().numpy()
        else:
            tensor_pytorch = dbg_hook.values[dbg_name].detach().numpy()
        tensor_finn = output_dict[dbg_name]
        assert np.isclose(tensor_finn, tensor_pytorch, atol=1e-5).all()
    os.remove(finn_onnx)