def test_eta_expand_constructor():
    """EtaExpand(expand_constructor=True) wraps a bare constructor reference.

    ``@main`` returns ``Cons`` unapplied; after the pass it should return a
    function that takes fresh parameters and applies ``Cons`` to them.
    """
    source_mod = relay.fromtext(r"""
        v0.0.4
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
            Cons
        }
    """)
    pipeline = _transform.Sequential([
        _transform.EtaExpand(expand_constructor=True),
    ])
    with _transform.PassContext(opt_level=3):
        expanded = pipeline(source_mod)
    reference = relay.fromtext(r"""
        v0.0.4
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
            fn [A](%x: A, %xs: List[A]) -> List[A] {
                Cons(%x, %xs)
            }
        }
    """)
    relay.analysis.assert_graph_equal(expanded['main'], reference['main'])
def test_eta_expand_global_var():
    """EtaExpand(expand_global_var=True) wraps a bare global-function reference.

    ``@main`` returns ``@aux`` unapplied; after the pass it should return a
    function that forwards its argument to ``@aux``.
    """
    source_mod = relay.fromtext(r"""
        v0.0.4
        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
            %x
        }
        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
            @aux
        }
    """)
    pipeline = tvm.transform.Sequential([
        _transform.EtaExpand(expand_global_var=True),
    ])
    with tvm.transform.PassContext(opt_level=3):
        expanded = pipeline(source_mod)
    reference = relay.fromtext(r"""
        v0.0.4
        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
            %x
        }
        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
            fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
                @aux(%x)
            }
        }
    """)
    tvm.ir.assert_structural_equal(
        expanded['main'], reference['main'], map_free_vars=True)
def test_eta_expand_global_var():
    """Eta-expanding ``@main`` turns the bare ``@aux`` reference into a lambda.

    Uses the older ``_transform`` pass API and graph-equality comparison.
    """
    source_mod = relay.fromtext(r"""
        v0.0.4
        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
            %x
        }
        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
            @aux
        }
    """)
    eta = _transform.EtaExpand(expand_global_var=True)
    pipeline = _transform.Sequential([eta])
    with _transform.PassContext(opt_level=3):
        expanded = pipeline(source_mod)
    reference = relay.fromtext(r"""
        v0.0.4
        def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
            %x
        }
        def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
            fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
                @aux(%x)
            }
        }
    """)
    relay.analysis.assert_graph_equal(expanded['main'], reference['main'])
Пример #4
0
def test_tensor_type():
    """Tensor annotations of rank 0, 1 and 2 parse to the matching TensorType."""
    cases = [
        ("let %_ : Tensor[(), float32] = (); ()", ()),
        ("let %_ : Tensor[(1,), float32] = (); ()", (1,)),
        ("let %_ : Tensor[(1, 1), float32] = (); ()", (1, 1)),
    ]
    for text, shape in cases:
        parsed = relay.fromtext(text)
        annotated_var = relay.Var("_", relay.TensorType(shape, "float32"))
        assert alpha_equal(parsed, relay.Let(annotated_var, UNIT, UNIT))
Пример #5
0
def test_function_type():
    """``fn (...) -> int32`` annotations with 0, 1 and 2 params parse to FuncType."""
    # no parameters
    nullary = relay.fromtext("""
            let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
            """)
    assert alpha_equal(
        nullary,
        relay.Let(relay.Var("_", relay.FuncType([], int32, [], [])),
                  relay.Function([], relay.const(0), int32, []), UNIT))

    # one parameter
    unary = relay.fromtext("""
            let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
            """)
    unary_fn = relay.Function([relay.Var("x", int32)], relay.const(0), int32, [])
    assert alpha_equal(
        unary,
        relay.Let(relay.Var("_", relay.FuncType([int32], int32, [], [])),
                  unary_fn, UNIT))

    # two parameters
    binary = relay.fromtext("""
            let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
            """)
    binary_params = [relay.Var("x", int32), relay.Var("y", int32)]
    binary_fn = relay.Function(binary_params, relay.const(0), int32, [])
    assert alpha_equal(
        binary,
        relay.Let(relay.Var("_", relay.FuncType([int32, int32], int32, [], [])),
                  binary_fn, UNIT))
def test_eta_expand_constructor():
    """Eta-expanding ``@main`` turns the bare ``Cons`` constructor into a lambda.

    Uses the ``tvm.transform`` pass API and structural equality with free-var
    mapping.
    """
    source_mod = relay.fromtext(r"""
        v0.0.4
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
            Cons
        }
    """)
    eta = _transform.EtaExpand(expand_constructor=True)
    pipeline = tvm.transform.Sequential([eta])
    with tvm.transform.PassContext(opt_level=3):
        expanded = pipeline(source_mod)
    reference = relay.fromtext(r"""
        v0.0.4
        type List[A] {
            Cons(A, List[A]),
            Nil,
        }
        def @main[A]() -> (fn(A, List[A]) -> List[A]) {
            fn [A](%x: A, %xs: List[A]) -> List[A] {
                Cons(%x, %xs)
            }
        }
    """)
    tvm.ir.assert_structural_equal(
        expanded['main'], reference['main'], map_free_vars=True)
Пример #7
0
def test_tuple():
    """Tuple literals of arity 0 through 3 parse to relay.Tuple of constants."""
    cases = [
        ("()", []),
        ("(0,)", [0]),
        ("(0, 1)", [0, 1]),
        ("(0, 1, 2)", [0, 1, 2]),
    ]
    for text, elements in cases:
        parsed = relay.fromtext(text)
        assert alpha_equal(parsed, relay.Tuple([relay.const(e) for e in elements]))
Пример #8
0
def test_ifelse_scope():
    """A ``let`` inside the true branch must not be visible in the else branch.

    ``%x`` is bound only in the ``if`` branch and referenced in the ``else``
    branch, so parsing is expected to fail. The test passes only when the
    parser raises.
    """
    # Build the source outside the ``try`` so a NameError here is not mistaken
    # for the expected parse failure.
    source = SEMVER + """
        if (True) {
            let %x = ();
            ()
        } else {
            %x
        }
        """
    try:
        relay.fromtext(source)
    except Exception:  # NOTE(review): the concrete parse-error class differs across TVM versions
        return
    raise AssertionError("expected a parse error: %x is used outside its scope")
Пример #9
0
def test_ifelse_scope():
    """Referencing a var bound only in the ``if`` branch from ``else`` must fail.

    The test passes only when parsing raises.
    """
    # Keep only the parse call inside ``try`` so unrelated errors surface.
    source = (
        SEMVER+
        """
        if (True) {
            let %x = ();
            ()
        } else {
            %x
        }
        """
    )
    try:
        relay.fromtext(source)
    except Exception:  # NOTE(review): the concrete parse-error class differs across TVM versions
        return
    raise AssertionError("expected a parse error: %x is used outside its scope")
Пример #10
0
def test_comments():
    """Line (``//``) and block (``/* */``) comments are ignored by the parser."""
    with_line_comment = relay.fromtext("""
            // This is a line comment!
            ()
        """)
    assert alpha_equal(with_line_comment, UNIT)

    with_block_comment = relay.fromtext("""
            /* This is a block comment!
               This is still a block comment!
            */
            ()
        """)
    assert alpha_equal(with_block_comment, UNIT)
Пример #11
0
def test_recursive_call():
    """A global definition that calls itself parses into a relay.Module."""
    source = SEMVER + """
        def @id(%x: int32) -> int32 {
            @id(%x)
        }
        """
    parsed = relay.fromtext(source)
    assert isinstance(parsed, relay.Module)
Пример #12
0
def test_defn():
    """A top-level ``def`` parses into a relay.Module."""
    source = SEMVER + """
        def @id(%x: int32) -> int32 {
            %x
        }
        """
    parsed = relay.fromtext(source)
    assert isinstance(parsed, relay.Module)
Пример #13
0
def test_func():
    """Function literals parse with 0, 1, 2 params and with type annotations."""
    # no parameters
    nullary = relay.fromtext("fn () { 0 }")
    assert alpha_equal(nullary, relay.Function([], relay.const(0), None, []))

    # single parameter
    unary = relay.fromtext("fn (%x) { %x }")
    assert alpha_equal(unary, relay.Function([X], X, None, []))

    # two parameters
    binary = relay.fromtext("fn (%x, %y) { %x + %y }")
    assert alpha_equal(binary, relay.Function([X, Y], relay.add(X, Y), None, []))

    # parameter and return annotations
    annotated = relay.fromtext("fn (%x: int32) -> int32 { %x }")
    assert alpha_equal(annotated, relay.Function([X_ANNO], X_ANNO, int32, []))
Пример #14
0
def test_seq():
    """``a; b`` parses as a let binding, and a braced value is allowed in ``let``."""
    sequenced = relay.fromtext("(); ()")
    assert alpha_equal(sequenced, relay.Let(_, UNIT, UNIT))

    braced_value = relay.fromtext("let %_ = { 1 }; ()")
    assert alpha_equal(braced_value, relay.Let(X, relay.const(1), UNIT))
Пример #15
0
def test_ifelse():
    """An ``if``/``else`` with constant branches parses to relay.If."""
    parsed = relay.fromtext("""
        if (True) {
            0
        } else {
            1
        }
        """)
    expected = relay.If(relay.const(True), relay.const(0), relay.const(1))
    assert alpha_equal(parsed, expected)
Пример #16
0
def test_recursive_call():
    """Parse a self-recursive global function and check a Module comes back."""
    module = relay.fromtext(SEMVER + """
        def @id(%x: int32) -> int32 {
            @id(%x)
        }
        """)
    assert isinstance(module, relay.Module)
Пример #17
0
def test_defn():
    """Parse a simple identity ``def`` and check a Module comes back."""
    module = relay.fromtext(SEMVER + """
        def @id(%x: int32) -> int32 {
            %x
        }
        """)
    assert isinstance(module, relay.Module)
Пример #18
0
def test_inf_loop_case():
    """Parse an ADT with nested constructor patterns inside ``match``.

    The test name suggests this input once sent the parser into an infinite
    loop; it only checks that parsing completes.
    """
    program = """
v0.0.4
type Arith[A] {
    Zero,
    Const(A),
    Plus(Arith[A], Arith[A])
}

def @shallow_opt[A](%a: Arith[A]) -> Arith[A] {
    match (%a) {
        Plus(Zero, %r) => %r,
        Plus(%l, Zero) => %l,
        _ => %a
    }
}
"""
    relay.fromtext(program)
Пример #19
0
def test_incomplete_type():
    """A ``_`` (incomplete) type annotation is accepted on a let binding."""
    parsed = relay.fromtext("let %_ : _ = (); ()")
    assert alpha_equal(parsed, relay.Let(_, UNIT, UNIT))
Пример #20
0
def test_let():
    """A simple ``let`` of a constant parses to relay.Let."""
    parsed = relay.fromtext("let %x = 1; ()")
    expected = relay.Let(X, relay.const(1), UNIT)
    assert alpha_equal(parsed, expected)
Пример #21
0
def astext(p, unify_free_vars=False):
    """Print ``p`` to text and, when it is closed, check the text reparses to ``p``.

    Expressions that contain free variables are returned without the
    round-trip check. When ``unify_free_vars`` is truthy, the structural
    comparison maps free variables onto each other.

    Returns the printed text.
    """
    txt = p.astext()
    if not (isinstance(p, Expr) and free_vars(p)):
        reparsed = relay.fromtext(txt)
        options = {"map_free_vars": True} if unify_free_vars else {}
        tvm.ir.assert_structural_equal(reparsed, p, **options)
    return txt
Пример #22
0
def astext(p, unify_free_vars=False):
    """Return ``p.astext()``, verifying the text reparses to ``p`` when ``p`` is closed.

    With ``unify_free_vars`` the check uses graph equality; otherwise it uses
    alpha equality. Expressions with free variables skip the check.
    """
    txt = p.astext()
    if isinstance(p, Expr) and free_vars(p):
        return txt
    reparsed = relay.fromtext(txt)
    check = assert_graph_equal if unify_free_vars else assert_alpha_equal
    check(reparsed, p)
    return txt
Пример #23
0
def test_tuple_type():
    """Tuple type annotations of arity 0, 1 and 2 parse to TupleType."""
    empty = relay.fromtext("""
        let %_: () = (); ()
        """)
    assert alpha_equal(
        empty, relay.Let(relay.Var("_", relay.TupleType([])), UNIT, UNIT))

    single = relay.fromtext("""
        let %_: (int32,) = (0,); ()
        """)
    single_var = relay.Var("_", relay.TupleType([int32]))
    assert alpha_equal(
        single, relay.Let(single_var, relay.Tuple([relay.const(0)]), UNIT))

    pair = relay.fromtext("""
        let %_: (int32, int32) = (0, 1); ()
        """)
    pair_var = relay.Var("_", relay.TupleType([int32, int32]))
    pair_value = relay.Tuple([relay.const(0), relay.const(1)])
    assert alpha_equal(pair, relay.Let(pair_var, pair_value, UNIT))
Пример #24
0
def test_int_literal():
    """Integer literals (signed, with leading zeros) parse to scalar constants."""
    assert isinstance(relay.fromtext(SEMVER+"1"), relay.Constant)
    assert isinstance(relay.fromtext(SEMVER+"1").data, tvm.ndarray.NDArray)

    expectations = [("1", 1), ("10", 10), ("0", 0), ("-100", -100), ("-05", -5)]
    for literal, value in expectations:
        assert get_scalar(relay.fromtext(SEMVER + literal)) == value
Пример #25
0
def test_int_literal():
    """Integer literals parse to Constant nodes carrying the expected scalar."""
    assert isinstance(relay.fromtext("1"), relay.Constant)
    assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray)

    for literal, value in (("1", 1), ("10", 10), ("0", 0), ("-100", -100), ("-05", -5)):
        assert get_scalar(relay.fromtext(literal)) == value
Пример #26
0
def test_vars():
    """Local vars, global vars and operator names parse to their node types."""
    # Temporary vars such as ``%1`` are not covered: they start with a digit
    # and are not accepted as identifiers here.

    local_let = relay.fromtext(SEMVER+"let %foo = (); %foo")
    assert isinstance(local_let.body, relay.Var)
    assert local_let.body.name_hint == "foo"

    global_ref = relay.fromtext(SEMVER+"@foo")
    assert isinstance(global_ref, relay.GlobalVar)
    assert global_ref.name_hint == "foo"

    operator = relay.fromtext(SEMVER+"foo")
    assert isinstance(operator, relay.Op)
    assert operator.name == "foo"
Пример #27
0
def test_vars():
    """Check that ``%name``, ``@name`` and bare ``name`` map to Var, GlobalVar, Op."""
    # Temp vars (``%1``) are skipped because names starting with a digit do
    # not parse as identifiers.

    bound = relay.fromtext("let %foo = (); %foo").body
    assert isinstance(bound, relay.Var)
    assert bound.name_hint == "foo"

    gvar = relay.fromtext("@foo")
    assert isinstance(gvar, relay.GlobalVar)
    assert gvar.name_hint == "foo"

    op_node = relay.fromtext("foo")
    assert isinstance(op_node, relay.Op)
    assert op_node.name == "foo"
Пример #28
0
def test_unapplied_constructor():
    """Constructors print with their signature in type definitions and bare as exprs.

    Both source fragments must appear verbatim (after stripping) in the
    printed module.
    """
    type_def_str = r"""
type List[A] {
  Cons(A, List[A]),
  Nil,
}
    """
    main_def_str = r"""
def @main[A]() -> fn (A, List[A]) -> List[A] {
  Cons
}
    """
    printed = str(relay.fromtext(SEMVER + type_def_str + main_def_str))
    for fragment in (type_def_str, main_def_str):
        assert fragment.strip() in printed
Пример #29
0
                "numeric_type": "fixed_point",
                "is_signed": True,
                "width": 32,
                "frac_width": 16,
            },
        }

    return memories


if __name__ == "__main__":
    # Command-line entry point: read a Relay IR text file, parse it, and print
    # the result of ``emit_calyx`` on the parsed module.
    import argparse

    parser = argparse.ArgumentParser(description="Lower Relay IR to Calyx.")
    parser.add_argument("file", help="Path to the Relay IR.")
    # ``file`` is a required positional argument, so argparse itself exits
    # with a usage error when it is missing; no separate None check is needed.
    args = parser.parse_args()

    with open(args.file, "r", encoding="utf-8") as ir_file:
        relay_ir = ir_file.read()
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if "v0.0.4" not in relay_ir:
        parser.error("TVM Requires `v0.0.4` at the top of the Relay IR file.")

    module = relay.fromtext(relay_ir)
    print(emit_calyx(module))
Пример #30
0
def test_let_global_var():
    """``let`` must bind a local (``%``) var; ``let @x`` is expected to be rejected.

    The test passes only when parsing raises.
    """
    source = SEMVER+"let @x = 1; ()"
    try:
        relay.fromtext(source)
    except Exception:  # NOTE(review): the concrete parse-error class differs across TVM versions
        return
    raise AssertionError("expected a parse error for `let @x = 1; ()`")
Пример #31
0
def test_negative():
    """Unary minus on a var parses to a call; stacked minus signs on a literal
    yield the sign-adjusted constant."""
    negated = relay.fromtext(SEMVER+"let %x = 1; -%x")
    assert isinstance(negated.body, relay.Call)
    for literal, value in (("--10", 10), ("---10", -10)):
        assert get_scalar(relay.fromtext(SEMVER + literal)) == value
Пример #32
0
def test_builtin_types():
    """Every name in ``TYPES`` is accepted as a let-binding type annotation."""
    template = "let %_ : {} = (); ()"
    for type_name in TYPES:
        relay.fromtext(template.format(type_name))
Пример #33
0
def test_graph_wrong_order():
    """Graph bindings must be numbered from ``%0``; starting at ``%1`` must fail.

    The test passes only when parsing raises.
    """
    source = SEMVER+"%1 = (); %1"
    try:
        relay.fromtext(source)
    except Exception:  # NOTE(review): the concrete parse-error class differs across TVM versions
        return
    raise AssertionError("expected a parse error for out-of-order graph binding `%1`")
Пример #34
0
def test_call():
    """Calls parse correctly for named, nullary, unary, binary, anonymous,
    curried and operator callees."""
    # select right function to call: simple ident case
    id_func = relay.Var("id")
    assert alpha_equal(
        relay.fromtext(
        """
        let %id = fn (%x) { %x };
        10 * %id(10)
        """
        ),
        relay.Let(
            id_func,
            relay.Function([X], X, None, []),
            relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)]))
        )
    )

    # 0 args
    constant = relay.Var("constant")
    assert alpha_equal(
        relay.fromtext(
        """
        let %constant = fn () { 0 };
        %constant()
        """
        ),
        relay.Let(
            constant,
            relay.Function([], relay.const(0), None, []),
            relay.Call(constant, [], None, None)
        )
    )

    # 1 arg
    id_var = relay.Var("id")
    assert alpha_equal(
        relay.fromtext(
            """
            let %id = fn (%x) { %x };
            %id(1)
            """
        ),
        relay.Let(
            id_var,
            relay.Function([X], X, None, []),
            relay.Call(id_var, [relay.const(1)], None, None)
        )
    )

    # 2 args
    multiply = relay.Var("multiply")
    assert alpha_equal(
        relay.fromtext(
        """
        let %multiply = fn (%x, %y) { %x * %y };
        %multiply(0, 0)
        """
        ),
        relay.Let(
            multiply,
            relay.Function(
                [X, Y],
                relay.multiply(X, Y),
                None,
                []
            ),
            relay.Call(multiply, [relay.const(0), relay.const(0)], None, None)
        )
    )

    # anonymous function
    assert alpha_equal(
        relay.fromtext(
        """
        (fn (%x) { %x })(0)
        """
        ),
        relay.Call(
            relay.Function(
                [X],
                X,
                None,
                []
            ),
            [relay.const(0)],
            None,
            None
        )
    )

    # curried function
    # The comparison result was previously discarded, so this case could
    # never fail; it is now asserted like the others.
    curried_mult = relay.Var("curried_mult")
    assert alpha_equal(
        relay.fromtext(
            """
            let %curried_mult =
                fn (%x) {
                fn (%y) {
                    %x * %y
                }
                };
            %curried_mult(0);
            %curried_mult(0)(0)
            """
        ),
        relay.Let(
            curried_mult,
            relay.Function(
                [X],
                relay.Function(
                    [Y],
                    relay.multiply(X, Y),
                    None,
                    []
                ),
                None,
                []
            ),
            relay.Let(
                _,
                relay.Call(curried_mult, [relay.const(0)], None, None),
                relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
            )
        )
    )

    # op
    # Likewise previously unasserted.
    assert alpha_equal(
        relay.fromtext("abs(1)"),
        relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
    )
Пример #35
0
def test_parens():
    """``*`` binds tighter than ``+``; explicit parentheses can override that."""
    implicit = relay.fromtext(SEMVER+"1 * 1 + 1")
    assert alpha_equal(implicit, relay.fromtext(SEMVER+"(1 * 1) + 1"))

    regrouped = relay.fromtext(SEMVER+"1 * (1 + 1)")
    assert not alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), regrouped)
Пример #36
0
def roundtrip(expr):
    """Assert that printing ``expr`` with ``astext`` and reparsing is graph-equal."""
    assert_graph_equal(relay.fromtext(expr.astext()), expr)
Пример #37
0
def roundtrip(expr):
    """Assert that ``str(expr)`` parses back to an alpha-equivalent expression."""
    text = str(expr)
    reparsed = relay.fromtext(text)
    assert_alpha_equal(reparsed, expr)
Пример #38
0
def test_builtin_types():
    """Each entry of ``TYPES`` parses as a let annotation under the version header."""
    for type_name in TYPES:
        source = SEMVER + "let %_ : {} = (); ()".format(type_name)
        relay.fromtext(source)
Пример #39
0
def parses_as(code, expr):
    # type: (str, relay.Expr) -> bool
    """Return whether ``code``, prefixed by the version header, parses to ``expr``."""
    source = "\n".join([SEMVER, code])
    parsed = relay.fromtext(source)
    return alpha_equal(parsed, expr)
Пример #40
0
def parse_text(code):
    """Parse ``code`` under the version header, check it round-trips, and return it."""
    parsed = relay.fromtext("\n".join([SEMVER, code]))
    roundtrip(parsed)
    return parsed
Пример #41
0
def test_float_literal():
    """Float literals, including scientific notation, parse to the expected value.

    Values that are not exactly representable are compared with ``isclose``;
    the rest are compared exactly.
    """
    # (literal text, expected value, compare approximately?)
    cases = [
        ("1.0", 1.0, False),
        ("1.56667", 1.56667, True),
        ("0.0", 0.0, False),
        ("-10.0", -10.0, False),
        # scientific notation
        ("1e-1", 1e-1, True),
        ("1e+1", 1e+1, False),
        ("1E-1", 1E-1, True),
        ("1E+1", 1E+1, False),
        ("1.0e-1", 1.0e-1, True),
        ("1.0e+1", 1.0e+1, False),
        ("1.0E-1", 1.0E-1, True),
        ("1.0E+1", 1.0E+1, False),
    ]
    for literal, expected, approximate in cases:
        value = get_scalar(relay.fromtext(SEMVER + literal))
        if approximate:
            assert isclose(value, expected)
        else:
            assert value == expected
Пример #42
0
def parse_text(code):
    """Return the expression parsed from ``code`` after asserting it round-trips."""
    source = SEMVER + "\n" + code
    result = relay.fromtext(source)
    roundtrip(result)
    return result
Пример #43
0
def test_let_op():
    """``let`` cannot bind an operator name (``x`` without ``%``); parsing must fail.

    The test passes only when parsing raises.
    """
    source = SEMVER+"let x = 1; ()"
    try:
        relay.fromtext(source)
    except Exception:  # NOTE(review): the concrete parse-error class differs across TVM versions
        return
    raise AssertionError("expected a parse error for `let x = 1; ()`")
Пример #44
0
def test_bool_literal():
    """``True`` and ``False`` parse to constants with the matching scalar value."""
    for literal, expected in (("True", True), ("False", False)):
        assert get_scalar(relay.fromtext(SEMVER + literal)) == expected
Пример #45
0
def test_op_assoc():
    """Binary operators group by precedence: ``*`` over ``+`` over ``<`` over ``==``."""
    pairs = [
        ("1 * 1 + 1 < 1 == 1", "(((1 * 1) + 1) < 1) == 1"),
        ("1 == 1 < 1 + 1 * 1", "1 == (1 < (1 + (1 * 1)))"),
    ]
    for bare, explicit in pairs:
        assert alpha_equal(relay.fromtext(SEMVER + bare), relay.fromtext(SEMVER + explicit))