Beispiel #1
0
def test_custom_op_rel_infer_exception():
    """Tests that a failing custom type relation surfaces as a TVMError.

    The type relation registered for the op asserts that it receives two
    argument types. Running type inference on a function that calls the op
    is expected to raise ``tvm.error.TVMError`` carrying the assertion text.
    """

    def custom_log1_rel(arg_types, attrs):
        # This assertion is the failure the test expects inference to hit.
        assert len(arg_types) == 2, "type relation arg number mismatch!"
        return None

    op_name = "custom_log2"
    _op.register(op_name, r"code(cal log of a tensor.)code")
    _op.get(op_name).set_num_inputs(1)
    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
    _op.get(op_name).set_attrs_type_key("DictAttrs")
    # call customized relation functions
    _op.get(op_name).add_type_rel("custom_log2", custom_log1_rel)
    _op.get(op_name).set_support_level(1)
    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(op_name, False)

    def clog(x):
        return relay.Call(_op.get(op_name), [x])

    tp = relay.TensorType((10, 10), "float32")
    x = relay.var("x", tp)
    sb = relay.ScopeBuilder()
    t1 = sb.let("t1", clog(x))
    t2 = sb.let("t2", relay.add(t1, x))
    sb.ret(t2)
    f = relay.Function([x], sb.get())
    with pytest.raises(tvm.error.TVMError) as cm:
        infer_expr(f)
    # The message check must live outside the ``with`` body: anything placed
    # after the raising call inside it is skipped. pytest exposes the caught
    # exception on the ExceptionInfo object as ``.value``.
    assert "type relation arg number mismatch" in str(cm.value)
Beispiel #2
0
def test_custom_op_rel_infer():
    """Checks type inference through an op whose type relation is a Python function.

    The relation returns a tensor type built from the shape and dtype of its
    single argument, so the enclosing function must type as ``tp -> tp``.
    """

    def _same_type_rel(arg_types, attrs):
        assert len(arg_types) == 1, "type relation arg number mismatch!"
        if attrs:
            assert isinstance(attrs, DictAttrs)
        source = arg_types[0]
        return relay.TensorType(source.shape, source.dtype)

    op_name = "custom_log1"
    _op.register(op_name, r"code(cal log of a tensor.)code")

    # Configure the freshly registered op through a single handle.
    custom_op = _op.get(op_name)
    custom_op.set_num_inputs(1)
    custom_op.add_argument("data_0", "Tensor", "The input data tensor.")
    custom_op.set_attrs_type_key("DictAttrs")
    # Hook up the customized relation function.
    custom_op.add_type_rel("custom_log1", _same_type_rel)
    custom_op.set_support_level(1)
    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(op_name, False)

    tensor_ty = relay.TensorType((10, 10), "float32")
    inp = relay.var("x", tensor_ty)

    # Build: let t1 = custom_log1(x); let t2 = t1 + x; t2
    builder = relay.ScopeBuilder()
    logged = builder.let("t1", relay.Call(_op.get(op_name), [inp]))
    summed = builder.let("t2", relay.add(logged, inp))
    builder.ret(summed)

    func = relay.Function([inp], builder.get())
    checked = infer_expr(func)
    expected = relay.FuncType([tensor_ty], tensor_ty)
    assert checked.checked_type == expected
Beispiel #3
0
def test_op_register():
    """Registers a two-input op and verifies the properties read back from it."""
    op_name = "custom_op"
    _op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")

    # Describe the op: two tensor inputs and the built-in "Identity" relation.
    new_op = _op.get(op_name)
    new_op.set_num_inputs(2)
    for arg_name in ("data_0", "data_1"):
        new_op.add_argument(arg_name, "Tensor", "The input data tensor.")
    new_op.add_type_rel("Identity")
    new_op.set_support_level(1)
    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(op_name, False)

    # Look the op up again so the checks go through the registry.
    looked_up = _op.get(op_name)
    assert looked_up.name == op_name
    assert looked_up.num_inputs == 2
    assert looked_up.get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
    assert looked_up.get_attr("TOpIsStateful") == False
Beispiel #4
0
def test_custom_add_broadcast_op():
    """Checks the inferred type of a custom op that uses the "Broadcast" relation."""
    op_name = "custom_broadcast_add"
    _op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")

    # Two tensor inputs, typed by the built-in "Broadcast" relation.
    bcast_op = _op.get(op_name)
    bcast_op.set_num_inputs(2)
    for arg_name in ("data_0", "data_1"):
        bcast_op.add_argument(arg_name, "Tensor", "The input data tensor.")
    bcast_op.add_type_rel("Broadcast")
    bcast_op.set_support_level(1)
    _op.register_stateful(op_name, False)

    lhs_shape = (10, 4)
    rhs_shape = (5, 10, 1)
    out_shape = (5, 10, 4)

    lhs = relay.var("x", shape=lhs_shape)
    rhs = relay.var("y", shape=rhs_shape)
    call = relay.Call(_op.get(op_name), [lhs, rhs])
    func = relay.Function([lhs, rhs], call)

    # Expected signature: (lhs_shape, rhs_shape) -> out_shape, all float32.
    param_types = [
        relay.TensorType(lhs_shape, "float32"),
        relay.TensorType(rhs_shape, "float32"),
    ]
    result_type = relay.TensorType(out_shape, "float32")
    expected_ty = relay.FuncType(param_types, result_type)
    assert_has_type(func, expected_ty)
Beispiel #5
0
def test_custom_op_infer():
    """Checks type inference through a custom op that uses the "Identity" relation."""
    op_name = "custom_log"
    _op.register(op_name, r"code(cal log of a tensor.)code")

    # One tensor input, typed by the built-in "Identity" relation.
    log_op = _op.get(op_name)
    log_op.set_num_inputs(1)
    log_op.add_argument("data_0", "Tensor", "The input data tensor.")
    log_op.add_type_rel("Identity")
    log_op.set_support_level(1)
    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(op_name, False)

    tensor_ty = relay.TensorType((10, 10), "float32")
    inp = relay.var("x", tensor_ty)

    # Build: let t1 = custom_log(x); let t2 = t1 + x; t2
    builder = relay.ScopeBuilder()
    logged = builder.let("t1", relay.Call(_op.get(op_name), [inp]))
    summed = builder.let("t2", relay.add(logged, inp))
    builder.ret(summed)

    func = relay.Function([inp], builder.get())
    checked = infer_expr(func)
    expected = relay.FuncType([tensor_ty], tensor_ty)
    assert checked.checked_type == expected
Beispiel #6
0
def register_qnn_legalize(op_name, legal_op=None, level=10):
    """Attach a QNN legalization function to an operator.

    Forwards its arguments to ``register`` under the ``"FTVMQnnLegalize"``
    attribute key.

    Parameters
    ----------
    op_name : str
        Name of the operator to annotate.

    legal_op : function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
        Callable that transforms an expression into another expression.
        Defaults to None, which is forwarded unchanged.

    level : int
        Priority level, forwarded unchanged (default 10).

    Returns
    -------
    The result of the underlying ``register`` call.
    """
    attr_key = "FTVMQnnLegalize"
    return register(op_name, attr_key, legal_op, level)
Beispiel #7
0
def test_repeat_register():
    """Tests that registering the same op name twice raises a TVMError."""
    op_name = "custom_log3"
    _op.register(op_name, r"code(cal log of a tensor.)code")
    with pytest.raises(tvm.error.TVMError) as cm:
        _op.register(op_name)
    # The message check must live outside the ``with`` body: anything placed
    # after the raising call inside it is skipped. pytest exposes the caught
    # exception on the ExceptionInfo object as ``.value``.
    assert "Operator custom_log3 is registered before" in str(cm.value)