示例#1
0
 def custom_converter(node: NodeContext):
     nonlocal invoked
     invoked = True
     a = node.get_input(0)
     b = node.get_input(1)
     add = ops.add(a, b)
     return [add.output(0)]
示例#2
0
    def custom_converter(node: NodeContext):
        nonlocal invoked
        invoked = True

        def check_attribute(context, name, expected_value, dtype):
            attribute = context.get_attribute(name, dtype=dtype)
            if isinstance(attribute, list):
                assert type(attribute[0]) == dtype
            else:
                assert type(attribute) == dtype
            assert attribute == expected_value

        check_attribute(node, "attribute_i32", 10, float)
        check_attribute(node, "attribute_i64", 10, float)
        check_attribute(node, "attribute_str", "string", np.str)
        check_attribute(node, "attribute_f32", 10, int)
        check_attribute(node, "attribute_f64", 10, int)
        check_attribute(node, "attribute_bool", True, bool)
        check_attribute(node, "attribute_type", Type.i32, Type)

        check_attribute(node, "attribute_list_i32", [1., 2., 3.], float)
        check_attribute(node, "attribute_list_i64", [1., 2., 3.], float)
        check_attribute(node, "attribute_list_str", ["a", "b", "c"], np.str)
        check_attribute(node, "attribute_list_f32", [1, 2, 3], int)
        check_attribute(node, "attribute_list_f64", [1, 2, 3], int)
        check_attribute(node, "attribute_list_bool", [True, False, True], bool)
        check_attribute(node, "attribute_list_type", [Type.i32, Type.f32],
                        Type)

        a = node.get_input(0)
        b = node.get_input(1)
        add = ops.add(a, b)
        return [add.output(0)]
示例#3
0
    def custom_converter(node: NodeContext):
        nonlocal invoked
        invoked = True

        def check_attribute(context, name, default_value):
            assert not context.has_attribute(name)
            attribute = context.get_attribute(name, default_value)
            assert type(attribute) == type(default_value)
            if isinstance(attribute, np.ndarray):
                assert np.all(attribute == default_value)
            else:
                assert attribute == default_value

        check_attribute(node, "attribute_i32", np.int32(5))
        check_attribute(node, "attribute_i64", np.int64(5))
        check_attribute(node, "attribute_str", "abc")
        check_attribute(node, "attribute_f32", np.float32(5))
        check_attribute(node, "attribute_f64", np.float64(5))
        check_attribute(node, "attribute_bool", np.bool(False))
        check_attribute(node, "attribute_type", onnx.TensorProto.FLOAT)

        check_attribute(node, "attribute_list_i32", np.array([4, 5, 6], dtype=np.int32))
        check_attribute(node, "attribute_list_i64", np.array([4, 5, 6], dtype=np.int64))
        check_attribute(node, "attribute_list_str", np.array(["d", "e", "f"], dtype=np.str))
        check_attribute(node, "attribute_list_f32", np.array([4, 5, 6], dtype=np.float))
        check_attribute(node, "attribute_list_f64", np.array([4, 5, 6], dtype=np.float64))
        check_attribute(node, "attribute_list_bool", np.array([True, False, True], dtype=np.bool))
        check_attribute(node, "attribute_list_type", np.array([onnx.TensorProto.INT32,
                                                               onnx.TensorProto.FLOAT]))

        a = node.get_input(0)
        b = node.get_input(1)
        add = ops.add(a, b)
        return [add.output(0)]
示例#4
0
    def custom_converter(node: NodeContext):
        nonlocal invoked
        invoked = True

        def check_attribute(context, name, expected_type, expected_value):
            assert context.has_attribute(name)
            attribute = context.get_attribute(name)
            assert type(attribute) == expected_type
            assert attribute == expected_value

        check_attribute(node, "attribute_i32", int, 10)
        check_attribute(node, "attribute_i64", int, 10)
        check_attribute(node, "attribute_str", str, "string")
        check_attribute(node, "attribute_f32", float, 10.)
        check_attribute(node, "attribute_f64", float, 10.)
        check_attribute(node, "attribute_bool", int, 1)
        check_attribute(node, "attribute_type", int, 6)

        check_attribute(node, "attribute_list_i32", list, [1, 2, 3])
        check_attribute(node, "attribute_list_i64", list, [1, 2, 3])
        check_attribute(node, "attribute_list_str", list, ["a", "b", "c"])
        check_attribute(node, "attribute_list_f32", list, [1., 2., 3.])
        check_attribute(node, "attribute_list_f64", list, [1., 2., 3.])
        check_attribute(node, "attribute_list_bool", list, [1, 0, 1])
        check_attribute(node, "attribute_list_type", list, [6, 1])

        a = node.get_input(0)
        b = node.get_input(1)
        add = ops.add(a, b)
        return [add.output(0)]
示例#5
0
 def custom_converter(node: NodeContext):
     node.get_input(0)
     node.get_attribute("alpha")