Ejemplo n.º 1
0
    def custom_converter(node: NodeContext):
        nonlocal invoked
        invoked = True

        def check_attribute(context, name, expected_type, expected_value):
            assert context.has_attribute(name)
            attribute = context.get_attribute(name)
            assert type(attribute) == expected_type
            assert attribute == expected_value

        check_attribute(node, "attribute_i32", int, 10)
        check_attribute(node, "attribute_i64", int, 10)
        check_attribute(node, "attribute_str", str, "string")
        check_attribute(node, "attribute_f32", float, 10.)
        check_attribute(node, "attribute_f64", float, 10.)
        check_attribute(node, "attribute_bool", int, 1)
        check_attribute(node, "attribute_type", int, 6)

        check_attribute(node, "attribute_list_i32", list, [1, 2, 3])
        check_attribute(node, "attribute_list_i64", list, [1, 2, 3])
        check_attribute(node, "attribute_list_str", list, ["a", "b", "c"])
        check_attribute(node, "attribute_list_f32", list, [1., 2., 3.])
        check_attribute(node, "attribute_list_f64", list, [1., 2., 3.])
        check_attribute(node, "attribute_list_bool", list, [1, 0, 1])
        check_attribute(node, "attribute_list_type", list, [6, 1])

        a = node.get_input(0)
        b = node.get_input(1)
        add = ops.add(a, b)
        return [add.output(0)]
Ejemplo n.º 2
0
    def custom_converter(node: NodeContext):
        nonlocal invoked
        invoked = True

        def check_attribute(context, name, default_value):
            assert not context.has_attribute(name)
            attribute = context.get_attribute(name, default_value)
            assert type(attribute) == type(default_value)
            if isinstance(attribute, np.ndarray):
                assert np.all(attribute == default_value)
            else:
                assert attribute == default_value

        check_attribute(node, "attribute_i32", np.int32(5))
        check_attribute(node, "attribute_i64", np.int64(5))
        check_attribute(node, "attribute_str", "abc")
        check_attribute(node, "attribute_f32", np.float32(5))
        check_attribute(node, "attribute_f64", np.float64(5))
        check_attribute(node, "attribute_bool", np.bool(False))
        check_attribute(node, "attribute_type", onnx.TensorProto.FLOAT)

        check_attribute(node, "attribute_list_i32", np.array([4, 5, 6], dtype=np.int32))
        check_attribute(node, "attribute_list_i64", np.array([4, 5, 6], dtype=np.int64))
        check_attribute(node, "attribute_list_str", np.array(["d", "e", "f"], dtype=np.str))
        check_attribute(node, "attribute_list_f32", np.array([4, 5, 6], dtype=np.float))
        check_attribute(node, "attribute_list_f64", np.array([4, 5, 6], dtype=np.float64))
        check_attribute(node, "attribute_list_bool", np.array([True, False, True], dtype=np.bool))
        check_attribute(node, "attribute_list_type", np.array([onnx.TensorProto.INT32,
                                                               onnx.TensorProto.FLOAT]))

        a = node.get_input(0)
        b = node.get_input(1)
        add = ops.add(a, b)
        return [add.output(0)]
Ejemplo n.º 3
0
 def custom_converter(node: NodeContext):
     node.get_input(0)
     node.get_attribute("alpha")