def test_custom_op_rel_infer_exception():
    """Check that a failing custom type relation surfaces as a ``TVMError``.

    Registers a one-input op whose Python type relation asserts that it
    receives two argument types, then runs type inference on a function
    calling that op. Inference is expected to raise ``tvm.error.TVMError``
    whose text contains the relation's assertion message.
    """

    def custom_log1_rel(arg_types, attrs):
        # The op is registered with a single input below, so this check is
        # the intended failure point of the test.
        assert len(arg_types) == 2, "type relation arg number mismatch!"
        return None

    op_name = "custom_log2"
    _op.register(op_name, r"code(cal log of a tensor.)code")
    _op.get(op_name).set_num_inputs(1)
    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
    _op.get(op_name).set_attrs_type_key("DictAttrs")
    # call customized relation functions
    _op.get(op_name).add_type_rel("custom_log2", custom_log1_rel)
    _op.get(op_name).set_support_level(1)
    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(op_name, False)

    def clog(x):
        return relay.Call(_op.get(op_name), [x])

    tp = relay.TensorType((10, 10), "float32")
    x = relay.var("x", tp)
    sb = relay.ScopeBuilder()
    t1 = sb.let("t1", clog(x))
    t2 = sb.let("t2", relay.add(t1, x))
    sb.ret(t2)
    f = relay.Function([x], sb.get())

    with pytest.raises(tvm.error.TVMError) as cm:
        infer_expr(f)
    # Inspect the captured exception only after the ``with`` block exits:
    # statements placed after the raising call inside the block never run.
    # The exception object lives on ``ExceptionInfo.value``.
    assert "type relation arg number mismatch" in str(cm.value)
def test_custom_op_rel_infer():
    """Type inference works for an op whose type relation is a Python callable.

    The relation returns a tensor type with the same shape and dtype as the
    single input, so the inferred function type must map ``tp`` to ``tp``.
    """

    def passthrough_rel(arg_types, attrs):
        assert len(arg_types) == 1, "type relation arg number mismatch!"
        if attrs:
            assert isinstance(attrs, DictAttrs)
        in_ty = arg_types[0]
        return relay.TensorType(in_ty.shape, in_ty.dtype)

    name = "custom_log1"
    _op.register(name, r"code(cal log of a tensor.)code")

    # Configure the freshly registered operator through a single handle.
    op = _op.get(name)
    op.set_num_inputs(1)
    op.add_argument("data_0", "Tensor", "The input data tensor.")
    op.set_attrs_type_key("DictAttrs")
    # Attach the Python-defined relation instead of a built-in one.
    op.add_type_rel("custom_log1", passthrough_rel)
    op.set_support_level(1)
    _op.register_pattern(name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(name, False)

    tensor_ty = relay.TensorType((10, 10), "float32")
    data = relay.var("x", tensor_ty)

    # let t1 = custom_log1(x); let t2 = t1 + x; t2
    builder = relay.ScopeBuilder()
    logged = builder.let("t1", relay.Call(_op.get(name), [data]))
    summed = builder.let("t2", relay.add(logged, data))
    builder.ret(summed)

    func = relay.Function([data], builder.get())
    inferred = infer_expr(func)
    assert inferred.checked_type == relay.FuncType([tensor_ty], tensor_ty)
def test_op_register():
    """Register a two-input op and read its properties back from the registry."""
    name = "custom_op"
    _op.register(name, r"code(Add two tensor with inner broadcasting.)code")

    op = _op.get(name)
    op.set_num_inputs(2)
    for arg_name in ("data_0", "data_1"):
        op.add_argument(arg_name, "Tensor", "The input data tensor.")
    # Use one of the default (built-in) relation functions.
    op.add_type_rel("Identity")
    op.set_support_level(1)
    _op.register_pattern(name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(name, False)

    # Look the operator up again so the checks go through the registry.
    registered = _op.get(name)
    assert registered.name == name
    assert registered.num_inputs == 2
    assert registered.get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
    assert registered.get_attr("TOpIsStateful") == False
def test_custom_add_broadcast_op():
    """Type inference for a custom two-input op using the "Broadcast" relation.

    Inputs of shape (10, 4) and (5, 10, 1) are expected to yield an output of
    shape (5, 10, 4).
    """
    name = "custom_broadcast_add"
    _op.register(name, r"code(Add two tensor with inner broadcasting.)code")

    op = _op.get(name)
    op.set_num_inputs(2)
    for arg_name in ("data_0", "data_1"):
        op.add_argument(arg_name, "Tensor", "The input data tensor.")
    # Use one of the default (built-in) relation functions.
    op.add_type_rel("Broadcast")
    op.set_support_level(1)
    _op.register_stateful(name, False)

    lhs = relay.var("x", shape=(10, 4))
    rhs = relay.var("y", shape=(5, 10, 1))
    call = relay.Call(_op.get(name), [lhs, rhs])
    func = relay.Function([lhs, rhs], call)

    lhs_ty = relay.TensorType((10, 4), "float32")
    rhs_ty = relay.TensorType((5, 10, 1), "float32")
    out_ty = relay.TensorType((5, 10, 4), "float32")
    assert_has_type(func, relay.FuncType([lhs_ty, rhs_ty], out_ty))
def test_custom_op_infer():
    """Type inference for a custom one-input op using the "Identity" relation."""
    name = "custom_log"
    _op.register(name, r"code(cal log of a tensor.)code")

    op = _op.get(name)
    op.set_num_inputs(1)
    op.add_argument("data_0", "Tensor", "The input data tensor.")
    # Use one of the default (built-in) relation functions.
    op.add_type_rel("Identity")
    op.set_support_level(1)
    _op.register_pattern(name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(name, False)

    tensor_ty = relay.TensorType((10, 10), "float32")
    data = relay.var("x", tensor_ty)

    # let t1 = custom_log(x); let t2 = t1 + x; t2
    builder = relay.ScopeBuilder()
    logged = builder.let("t1", relay.Call(_op.get(name), [data]))
    summed = builder.let("t2", relay.add(logged, data))
    builder.ret(summed)

    func = relay.Function([data], builder.get())
    inferred = infer_expr(func)
    assert inferred.checked_type == relay.FuncType([tensor_ty], tensor_ty)
def register_qnn_legalize(op_name, legal_op=None, level=10):
    """Register a legalization function for a QNN operator.

    Parameters
    ----------
    op_name : str
        Name of the operator.

    legal_op : function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
        Function that transforms one expression into another.

    level : int
        Priority level of the registration.

    Returns
    -------
    The result of ``register`` for the "FTVMQnnLegalize" attribute.
    """
    attr_key = "FTVMQnnLegalize"
    return register(op_name, attr_key, legal_op, level)
def test_repeat_register():
    """Check that registering the same operator name twice raises ``TVMError``."""
    op_name = "custom_log3"
    _op.register(op_name, r"code(cal log of a tensor.)code")

    with pytest.raises(tvm.error.TVMError) as cm:
        _op.register(op_name)
    # Inspect the captured exception only after the ``with`` block exits:
    # statements placed after the raising call inside the block never run.
    # The exception object lives on ``ExceptionInfo.value``.
    assert "Operator custom_log3 is registered before" in str(cm.value)