Example #1
0
def test_onnx_conversion_extension_attribute_with_default_value():
    """Check that ``get_attribute`` returns the supplied default for absent attributes.

    A custom converter is registered for the ONNX ``Add`` operation. Inside
    the converter, every probed attribute name is first asserted to be absent
    from the node, then ``get_attribute(name, default)`` is asserted to hand
    back a value with the same type as, and equal to, the given default.
    Finally the test asserts that the converter was actually invoked during
    conversion.
    """
    skip_if_onnx_frontend_is_disabled()

    # use specific (openvino.frontend.onnx) import here
    from openvino.frontend.onnx import ConversionExtension
    from openvino.frontend import NodeContext
    import openvino.runtime.opset8 as ops

    # use the model without attributes
    fe = fem.load_by_model(onnx_model_filename)
    assert fe
    assert fe.get_name() == "onnx"

    # Flipped by the converter so the test can prove the extension ran.
    invoked = False

    def custom_converter(node: NodeContext):
        """Probe default-valued attributes, then lower the node to ``ops.add``."""
        nonlocal invoked
        invoked = True

        def check_attribute(context, name, default_value):
            """Assert ``name`` is absent and the default comes back unchanged."""
            assert not context.has_attribute(name)
            attribute = context.get_attribute(name, default_value)
            # Exact type match is intended here (not isinstance): the
            # returned value must not be silently widened or narrowed.
            assert type(attribute) is type(default_value)
            if isinstance(attribute, np.ndarray):
                # Element-wise comparison yields an array; reduce it.
                assert np.all(attribute == default_value)
            else:
                assert attribute == default_value

        # Scalar defaults.
        check_attribute(node, "attribute_i32", np.int32(5))
        check_attribute(node, "attribute_i64", np.int64(5))
        check_attribute(node, "attribute_str", "abc")
        check_attribute(node, "attribute_f32", np.float32(5))
        check_attribute(node, "attribute_f64", np.float64(5))
        # Plain Python bool: the removed ``np.bool`` alias was the builtin.
        check_attribute(node, "attribute_bool", False)
        check_attribute(node, "attribute_type", onnx.TensorProto.FLOAT)

        # List defaults. The builtins ``str``/``float``/``bool`` replace the
        # ``np.str``/``np.float``/``np.bool`` aliases removed in NumPy 1.24;
        # they denote exactly the same dtypes.
        check_attribute(node, "attribute_list_i32",
                        np.array([4, 5, 6], dtype=np.int32))
        check_attribute(node, "attribute_list_i64",
                        np.array([4, 5, 6], dtype=np.int64))
        check_attribute(node, "attribute_list_str",
                        np.array(["d", "e", "f"], dtype=str))
        # NOTE(review): ``float`` is float64 (same as the former ``np.float``),
        # although the attribute name suggests float32 - confirm the intent.
        check_attribute(node, "attribute_list_f32",
                        np.array([4, 5, 6], dtype=float))
        check_attribute(node, "attribute_list_f64",
                        np.array([4, 5, 6], dtype=np.float64))
        check_attribute(node, "attribute_list_bool",
                        np.array([True, False, True], dtype=bool))
        check_attribute(
            node, "attribute_list_type",
            np.array([onnx.TensorProto.INT32, onnx.TensorProto.FLOAT]))

        # Produce the converted result from the node's two inputs.
        a = node.get_input(0)
        b = node.get_input(1)
        add = ops.add(a, b)
        return [add.output(0)]

    fe.add_extension(ConversionExtension("Add", custom_converter))
    input_model = fe.load(onnx_model_filename)
    assert input_model
    model = fe.convert(input_model)
    assert model
    assert invoked