def custom_converter(node: NodeContext): nonlocal invoked invoked = True a = node.get_input(0) b = node.get_input(1) add = ops.add(a, b) return [add.output(0)]
def custom_converter(node: NodeContext): nonlocal invoked invoked = True def check_attribute(context, name, expected_value, dtype): attribute = context.get_attribute(name, dtype=dtype) if isinstance(attribute, list): assert type(attribute[0]) == dtype else: assert type(attribute) == dtype assert attribute == expected_value check_attribute(node, "attribute_i32", 10, float) check_attribute(node, "attribute_i64", 10, float) check_attribute(node, "attribute_str", "string", np.str) check_attribute(node, "attribute_f32", 10, int) check_attribute(node, "attribute_f64", 10, int) check_attribute(node, "attribute_bool", True, bool) check_attribute(node, "attribute_type", Type.i32, Type) check_attribute(node, "attribute_list_i32", [1., 2., 3.], float) check_attribute(node, "attribute_list_i64", [1., 2., 3.], float) check_attribute(node, "attribute_list_str", ["a", "b", "c"], np.str) check_attribute(node, "attribute_list_f32", [1, 2, 3], int) check_attribute(node, "attribute_list_f64", [1, 2, 3], int) check_attribute(node, "attribute_list_bool", [True, False, True], bool) check_attribute(node, "attribute_list_type", [Type.i32, Type.f32], Type) a = node.get_input(0) b = node.get_input(1) add = ops.add(a, b) return [add.output(0)]
def custom_converter(node: NodeContext): nonlocal invoked invoked = True def check_attribute(context, name, default_value): assert not context.has_attribute(name) attribute = context.get_attribute(name, default_value) assert type(attribute) == type(default_value) if isinstance(attribute, np.ndarray): assert np.all(attribute == default_value) else: assert attribute == default_value check_attribute(node, "attribute_i32", np.int32(5)) check_attribute(node, "attribute_i64", np.int64(5)) check_attribute(node, "attribute_str", "abc") check_attribute(node, "attribute_f32", np.float32(5)) check_attribute(node, "attribute_f64", np.float64(5)) check_attribute(node, "attribute_bool", np.bool(False)) check_attribute(node, "attribute_type", onnx.TensorProto.FLOAT) check_attribute(node, "attribute_list_i32", np.array([4, 5, 6], dtype=np.int32)) check_attribute(node, "attribute_list_i64", np.array([4, 5, 6], dtype=np.int64)) check_attribute(node, "attribute_list_str", np.array(["d", "e", "f"], dtype=np.str)) check_attribute(node, "attribute_list_f32", np.array([4, 5, 6], dtype=np.float)) check_attribute(node, "attribute_list_f64", np.array([4, 5, 6], dtype=np.float64)) check_attribute(node, "attribute_list_bool", np.array([True, False, True], dtype=np.bool)) check_attribute(node, "attribute_list_type", np.array([onnx.TensorProto.INT32, onnx.TensorProto.FLOAT])) a = node.get_input(0) b = node.get_input(1) add = ops.add(a, b) return [add.output(0)]
def custom_converter(node: NodeContext): nonlocal invoked invoked = True def check_attribute(context, name, expected_type, expected_value): assert context.has_attribute(name) attribute = context.get_attribute(name) assert type(attribute) == expected_type assert attribute == expected_value check_attribute(node, "attribute_i32", int, 10) check_attribute(node, "attribute_i64", int, 10) check_attribute(node, "attribute_str", str, "string") check_attribute(node, "attribute_f32", float, 10.) check_attribute(node, "attribute_f64", float, 10.) check_attribute(node, "attribute_bool", int, 1) check_attribute(node, "attribute_type", int, 6) check_attribute(node, "attribute_list_i32", list, [1, 2, 3]) check_attribute(node, "attribute_list_i64", list, [1, 2, 3]) check_attribute(node, "attribute_list_str", list, ["a", "b", "c"]) check_attribute(node, "attribute_list_f32", list, [1., 2., 3.]) check_attribute(node, "attribute_list_f64", list, [1., 2., 3.]) check_attribute(node, "attribute_list_bool", list, [1, 0, 1]) check_attribute(node, "attribute_list_type", list, [6, 1]) a = node.get_input(0) b = node.get_input(1) add = ops.add(a, b) return [add.output(0)]
def custom_converter(node: NodeContext):
    """Fetch input 0 and attribute ``"alpha"``, then discard both.

    Nothing is built or returned (the result is ``None``). NOTE(review): this
    looks intended to provoke an error from ``get_attribute`` when ``"alpha"``
    is absent; confirm against the surrounding test.
    """
    _ = node.get_input(0)
    _ = node.get_attribute("alpha")
    return None