def test_eta_expand_constructor():
    """EtaExpand(expand_constructor=True) rewrites a bare `Cons` into an explicit lambda."""
    list_prelude = r"""
v0.0.4
type List[A] {
    Cons(A, List[A]),
    Nil,
}
"""
    before = relay.fromtext(list_prelude + r"""
def @main[A]() -> (fn(A, List[A]) -> List[A]) {
    Cons
}
""")
    pipeline = _transform.Sequential(
        [_transform.EtaExpand(expand_constructor=True)])
    with _transform.PassContext(opt_level=3):
        after = pipeline(before)
    expected = relay.fromtext(list_prelude + r"""
def @main[A]() -> (fn(A, List[A]) -> List[A]) {
    fn [A](%x: A, %xs: List[A]) -> List[A] {
        Cons(%x, %xs)
    }
}
""")
    relay.analysis.assert_graph_equal(after['main'], expected['main'])
def test_eta_expand_global_var():
    """EtaExpand(expand_global_var=True) wraps a bare `@aux` reference in a lambda."""
    aux_prelude = r"""
v0.0.4
def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
    %x
}
"""
    mod = relay.fromtext(aux_prelude + r"""
def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
    @aux
}
""")
    eta_expand = tvm.transform.Sequential(
        [_transform.EtaExpand(expand_global_var=True)])
    with tvm.transform.PassContext(opt_level=3):
        mod = eta_expand(mod)
    expected = relay.fromtext(aux_prelude + r"""
def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
    fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
        @aux(%x)
    }
}
""")
    # map_free_vars=True lets otherwise-unbound vars be matched to each other.
    tvm.ir.assert_structural_equal(mod['main'], expected['main'],
                                   map_free_vars=True)
def test_eta_expand_global_var():
    """EtaExpand(expand_global_var=True) eta-expands a global function reference."""
    aux_def = r"""
v0.0.4
def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] {
    %x
}
"""
    source_main = r"""
def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
    @aux
}
"""
    expanded_main = r"""
def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) {
    fn (%x: Tensor[(), int32]) -> Tensor[(), int32] {
        @aux(%x)
    }
}
"""
    module = relay.fromtext(aux_def + source_main)
    passes = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
    with _transform.PassContext(opt_level=3):
        module = passes(module)
    expected = relay.fromtext(aux_def + expanded_main)
    relay.analysis.assert_graph_equal(module['main'], expected['main'])
def test_tensor_type():
    """Tensor annotations of rank 0, 1 and 2 parse to the matching float32 TensorType."""
    for shape in [(), (1,), (1, 1)]:
        # str() of these tuples gives "()", "(1,)" and "(1, 1)" -- the Relay spelling.
        source = "let %_ : Tensor[{}, float32] = (); ()".format(shape)
        annotated = relay.Var("_", relay.TensorType(shape, "float32"))
        assert alpha_equal(relay.fromtext(source),
                           relay.Let(annotated, UNIT, UNIT))
def test_function_type():
    """Function-type annotations with zero, one and two int32 parameters parse correctly."""
    def let_of_func(param_types, params):
        # Expected AST for `let %_: fn(param_types) -> int32 = fn(params) -> int32 { 0 }; ()`.
        return relay.Let(
            relay.Var("_", relay.FuncType(param_types, int32, [], [])),
            relay.Function(params, relay.const(0), int32, []),
            UNIT)

    assert alpha_equal(
        relay.fromtext("""
        let %_: fn () -> int32 = fn () -> int32 { 0 };
        ()
        """),
        let_of_func([], []))

    assert alpha_equal(
        relay.fromtext("""
        let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 };
        ()
        """),
        let_of_func([int32], [relay.Var("x", int32)]))

    assert alpha_equal(
        relay.fromtext("""
        let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 };
        ()
        """),
        let_of_func([int32, int32],
                    [relay.Var("x", int32), relay.Var("y", int32)]))
def test_eta_expand_constructor():
    """EtaExpand(expand_constructor=True) turns a bare `Cons` into an explicit lambda."""
    list_adt = r"""
v0.0.4
type List[A] {
    Cons(A, List[A]),
    Nil,
}
"""
    mod = relay.fromtext(list_adt + r"""
def @main[A]() -> (fn(A, List[A]) -> List[A]) {
    Cons
}
""")
    expander = tvm.transform.Sequential(
        [_transform.EtaExpand(expand_constructor=True)])
    with tvm.transform.PassContext(opt_level=3):
        mod = expander(mod)
    expected = relay.fromtext(list_adt + r"""
def @main[A]() -> (fn(A, List[A]) -> List[A]) {
    fn [A](%x: A, %xs: List[A]) -> List[A] {
        Cons(%x, %xs)
    }
}
""")
    tvm.ir.assert_structural_equal(mod['main'], expected['main'],
                                   map_free_vars=True)
def test_tuple():
    """Tuple literals of arity 0-3 (incl. the trailing-comma singleton) parse correctly."""
    cases = [
        ("()", []),
        ("(0,)", [0]),
        ("(0, 1)", [0, 1]),
        ("(0, 1, 2)", [0, 1, 2]),
    ]
    for text, values in cases:
        expected = relay.Tuple([relay.const(v) for v in values])
        assert alpha_equal(relay.fromtext(text), expected)
def test_ifelse_scope():
    """Reference `%x` in the else-branch although it is only bound in the then-branch.

    NOTE(review): this looks like an out-of-scope use, so the parse is presumably
    expected to fail (e.g. via a decorator outside this view) -- confirm.
    """
    program = SEMVER + """
    if (True) {
        let %x = ();
        ()
    } else {
        %x
    }
    """
    relay.fromtext(program)
def test_ifelse_scope():
    """Parse an if/else whose else-branch uses a var bound only inside the then-branch.

    NOTE(review): presumably expected to raise a parse error (any decorator saying
    so is outside this view) -- confirm.
    """
    branches = """
    if (True) {
        let %x = ();
        ()
    } else {
        %x
    }
    """
    relay.fromtext(SEMVER + branches)
def test_comments():
    """Line (`//`) and block (`/* ... */`) comments are skipped by the parser."""
    commented_programs = (
        """
        // This is a line comment!
        ()
        """,
        """
        /* This is a block comment!
            This is still a block comment!
        */
        ()
        """,
    )
    for program in commented_programs:
        assert alpha_equal(relay.fromtext(program), UNIT)
def test_recursive_call():
    """A global function may call itself; parsing the definition yields a Module."""
    self_calling = """
    def @id(%x: int32) -> int32 {
        @id(%x)
    }
    """
    parsed = relay.fromtext(SEMVER + self_calling)
    assert isinstance(parsed, relay.Module)
def test_defn():
    """A top-level `def` parses to a relay.Module."""
    identity_src = """
    def @id(%x: int32) -> int32 {
        %x
    }
    """
    parsed = relay.fromtext(SEMVER + identity_src)
    assert isinstance(parsed, relay.Module)
def test_func():
    """Function literals with 0-2 parameters, with and without type annotations."""
    cases = [
        # no parameters
        ("fn () { 0 }", relay.Function([], relay.const(0), None, [])),
        # one parameter
        ("fn (%x) { %x }", relay.Function([X], X, None, [])),
        # two parameters
        ("fn (%x, %y) { %x + %y }",
         relay.Function([X, Y], relay.add(X, Y), None, [])),
        # parameter and return-type annotations
        ("fn (%x: int32) -> int32 { %x }",
         relay.Function([X_ANNO], X_ANNO, int32, [])),
    ]
    for source, expected in cases:
        assert alpha_equal(relay.fromtext(source), expected)
def test_seq():
    """`a; b` parses as a let with the `_` binder; a `{ e }` block parses as `e`."""
    sequenced = relay.Let(_, UNIT, UNIT)
    assert alpha_equal(relay.fromtext("(); ()"), sequenced)

    block_bound = relay.Let(X, relay.const(1), UNIT)
    assert alpha_equal(relay.fromtext("let %_ = { 1 }; ()"), block_bound)
def test_ifelse():
    """An if/else expression parses to relay.If over constant condition and branches."""
    source = """
    if (True) {
        0
    } else {
        1
    }
    """
    expected = relay.If(relay.const(True), relay.const(0), relay.const(1))
    assert alpha_equal(relay.fromtext(source), expected)
def test_recursive_call():
    """Parse a self-recursive global definition and check a Module comes back."""
    module = relay.fromtext(
        SEMVER
        + """
        def @id(%x: int32) -> int32 {
            @id(%x)
        }
        """
    )
    assert isinstance(module, relay.Module)
def test_defn():
    """Parse a simple identity definition and check a Module comes back."""
    module = relay.fromtext(
        SEMVER
        + """
        def @id(%x: int32) -> int32 {
            %x
        }
        """
    )
    assert isinstance(module, relay.Module)
def test_inf_loop_case():
    """Parse a generic ADT plus a function matching on nested constructor patterns.

    NOTE(review): the name suggests this input once caused an infinite loop
    somewhere in parsing; the test only requires that `fromtext` returns.
    """
    adt_def = """
v0.0.4
type Arith[A] {
    Zero,
    Const(A),
    Plus(Arith[A], Arith[A])
}
"""
    fn_def = """
def @shallow_opt[A](%a: Arith[A]) -> Arith[A] {
    match (%a) {
        Plus(Zero, %r) => %r,
        Plus(%l, Zero) => %l,
        _ => %a
    }
}
"""
    relay.fromtext(adt_def + fn_def)
def test_incomplete_type():
    """A `_` type annotation (an unknown type) is accepted on a let binding."""
    expected = relay.Let(_, UNIT, UNIT)
    parsed = relay.fromtext("let %_ : _ = (); ()")
    assert alpha_equal(parsed, expected)
def test_let():
    """`let %x = 1; ()` parses to a Let binding the constant 1 over unit."""
    expected = relay.Let(X, relay.const(1), UNIT)
    parsed = relay.fromtext("let %x = 1; ()")
    assert alpha_equal(parsed, expected)
def astext(p, unify_free_vars=False):
    """Print `p` as Relay text and, when possible, check that the text parses back to `p`.

    Args:
        p: the Relay object to print.
        unify_free_vars: forwarded as `map_free_vars` to the structural comparison.

    Returns:
        The printed text.
    """
    text = p.astext()
    if isinstance(p, Expr) and free_vars(p):
        # Expressions with free variables skip the round-trip check
        # (presumably because they cannot be re-parsed standalone).
        return text
    reparsed = relay.fromtext(text)
    # assert_structural_equal defaults to map_free_vars=False, matching the
    # unify_free_vars=False branch of the two-call form.
    tvm.ir.assert_structural_equal(reparsed, p, map_free_vars=unify_free_vars)
    return text
def astext(p, unify_free_vars=False):
    """Print `p` as Relay text and verify it round-trips through the parser.

    With `unify_free_vars` the check uses graph equality, otherwise alpha equality.
    Expressions with free variables are returned without the round-trip check.
    """
    text = p.astext()
    if isinstance(p, Expr) and free_vars(p):
        return text
    reparsed = relay.fromtext(text)
    check_equal = assert_graph_equal if unify_free_vars else assert_alpha_equal
    check_equal(reparsed, p)
    return text
def test_tuple_type():
    """Tuple-type annotations of arity 0-2 parse to the matching TupleType."""
    cases = [
        ("let %_: () = (); ()", [], UNIT),
        ("let %_: (int32,) = (0,); ()", [int32],
         relay.Tuple([relay.const(0)])),
        ("let %_: (int32, int32) = (0, 1); ()", [int32, int32],
         relay.Tuple([relay.const(0), relay.const(1)])),
    ]
    for source, field_types, value in cases:
        binder = relay.Var("_", relay.TupleType(field_types))
        assert alpha_equal(relay.fromtext(source),
                           relay.Let(binder, value, UNIT))
def test_int_literal():
    """Integer literals (incl. negative and zero-padded) parse to scalar Constants."""
    one = relay.fromtext(SEMVER + "1")
    assert isinstance(one, relay.Constant)
    assert isinstance(one.data, tvm.ndarray.NDArray)

    expected_values = [("1", 1), ("10", 10), ("0", 0),
                       ("-100", -100), ("-05", -5)]
    for literal, value in expected_values:
        assert get_scalar(relay.fromtext(SEMVER + literal)) == value
def test_int_literal():
    """Integer literals (incl. negative and zero-padded) parse to scalar Constants."""
    parsed_one = relay.fromtext("1")
    assert isinstance(parsed_one, relay.Constant)
    assert isinstance(parsed_one.data, tvm.ndarray.NDArray)

    for text, number in (("1", 1), ("10", 10), ("0", 0),
                         ("-100", -100), ("-05", -5)):
        assert get_scalar(relay.fromtext(text)) == number
def test_vars():
    """Local (`%foo`), global (`@foo`) and operator (`foo`) identifiers parse correctly.

    Numeric temporaries such as `%1` are not exercised: per the original note,
    they do not work here because they start with a digit.
    """
    let_expr = relay.fromtext(SEMVER + "let %foo = (); %foo")
    assert isinstance(let_expr.body, relay.Var)
    assert let_expr.body.name_hint == "foo"

    gvar = relay.fromtext(SEMVER + "@foo")
    assert isinstance(gvar, relay.GlobalVar)
    assert gvar.name_hint == "foo"

    operator = relay.fromtext(SEMVER + "foo")
    assert isinstance(operator, relay.Op)
    assert operator.name == "foo"
def test_vars():
    """Local, global and operator identifiers parse to Var, GlobalVar and Op.

    Numeric temporaries like `%1` are intentionally not covered (the original
    note says they do not work because they start with a digit).
    """
    local_body = relay.fromtext("let %foo = (); %foo").body
    assert isinstance(local_body, relay.Var)
    assert local_body.name_hint == "foo"

    global_ref = relay.fromtext("@foo")
    assert isinstance(global_ref, relay.GlobalVar)
    assert global_ref.name_hint == "foo"

    op_ref = relay.fromtext("foo")
    assert isinstance(op_ref, relay.Op)
    assert op_ref.name == "foo"
def test_unapplied_constructor():
    """Constructors print with their signature in type defs and bare when used as exprs."""
    type_def_str = r"""
type List[A] {
  Cons(A, List[A]),
  Nil,
}
    """
    main_def_str = r"""
def @main[A]() -> fn (A, List[A]) -> List[A] {
  Cons
}
    """
    mod = relay.fromtext(SEMVER + type_def_str + main_def_str)
    printed = str(mod)
    # Both snippets must appear verbatim (modulo outer whitespace) in the printout.
    for snippet in (type_def_str, main_def_str):
        assert snippet.strip() in printed
"numeric_type": "fixed_point", "is_signed": True, "width": 32, "frac_width": 16, }, } return memories if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Lower Relay IR to Calyx.") parser.add_argument("file", help="Path to the Relay IR.") args = parser.parse_args() if args.file is None: raise Exception( "The TVM Relay visitor requires a file containing the Relay IR." ) with open(args.file, "r") as file: relay_ir = file.read() assert ( "v0.0.4" in relay_ir ), "TVM Requires `v0.0.4` at the top of the Relay IR file." relay_ir = relay.fromtext(relay_ir) print(emit_calyx(relay_ir))
def test_let_global_var():
    """Attempt to let-bind a global var (`@x`).

    NOTE(review): a global var is not a valid let binder, so this is presumably
    expected to raise a parse error (via a decorator outside this view) -- confirm.
    """
    source = SEMVER + "let @x = 1; ()"
    relay.fromtext(source)
def test_negative():
    """Unary minus on a var is a Call; repeated minus signs on a literal compose."""
    negated = relay.fromtext(SEMVER + "let %x = 1; -%x")
    assert isinstance(negated.body, relay.Call)

    for literal, value in (("--10", 10), ("---10", -10)):
        assert get_scalar(relay.fromtext(SEMVER + literal)) == value
def test_builtin_types():
    """Every builtin type name in TYPES is accepted as a let-binding annotation."""
    template = "let %_ : {} = (); ()"
    for type_name in TYPES:
        source = template.format(type_name)
        relay.fromtext(source)
def test_graph_wrong_order():
    """Parse a graph binding that starts at `%1` instead of the first index.

    NOTE(review): judging by the name, this out-of-order graph var is presumably
    expected to raise a parse error (via a decorator outside this view) -- confirm.
    """
    source = SEMVER + "%1 = (); %1"
    relay.fromtext(source)
def test_call():
    """Call syntax: named, nullary, unary, binary, anonymous, curried and operator calls.

    Every case is checked with ``assert alpha_equal(parsed, expected)``.
    """
    # select right function to call: simple ident case
    id_func = relay.Var("id")
    assert alpha_equal(
        relay.fromtext(
            """
            let %id = fn (%x) { %x };
            10 * %id(10)
            """
        ),
        relay.Let(
            id_func,
            relay.Function([X], X, None, []),
            relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)]))
        )
    )

    # 0 args
    constant = relay.Var("constant")
    assert alpha_equal(
        relay.fromtext(
            """
            let %constant = fn () { 0 };
            %constant()
            """
        ),
        relay.Let(
            constant,
            relay.Function([], relay.const(0), None, []),
            relay.Call(constant, [], None, None)
        )
    )

    # 1 arg
    id_var = relay.Var("id")
    assert alpha_equal(
        relay.fromtext(
            """
            let %id = fn (%x) { %x };
            %id(1)
            """
        ),
        relay.Let(
            id_var,
            relay.Function([X], X, None, []),
            relay.Call(id_var, [relay.const(1)], None, None)
        )
    )

    # 2 args
    multiply = relay.Var("multiply")
    assert alpha_equal(
        relay.fromtext(
            """
            let %multiply = fn (%x, %y) { %x * %y };
            %multiply(0, 0)
            """
        ),
        relay.Let(
            multiply,
            relay.Function(
                [X, Y],
                relay.multiply(X, Y),
                None,
                []
            ),
            relay.Call(multiply, [relay.const(0), relay.const(0)], None, None)
        )
    )

    # anonymous function
    assert alpha_equal(
        relay.fromtext(
            """
            (fn (%x) { %x })(0)
            """
        ),
        relay.Call(
            relay.Function(
                [X],
                X,
                None,
                []
            ),
            [relay.const(0)],
            None,
            None
        )
    )

    # curried function
    # `assert` is required: a bare alpha_equal(...) call discards its result and
    # the case could never fail.
    curried_mult = relay.Var("curried_mult")
    assert alpha_equal(
        relay.fromtext(
            """
            let %curried_mult =
                fn (%x) {
                fn (%y) {
                    %x * %y
                }
                };
            %curried_mult(0);
            %curried_mult(0)(0)
            """
        ),
        relay.Let(
            curried_mult,
            relay.Function(
                [X],
                relay.Function(
                    [Y],
                    relay.multiply(X, Y),
                    None,
                    []
                ),
                None,
                []
            ),
            relay.Let(
                _,
                relay.Call(curried_mult, [relay.const(0)], None, None),
                relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None),
                           [relay.const(0)], None, None)
            )
        )
    )

    # op
    # Asserted for the same reason as the curried case above.
    assert alpha_equal(
        relay.fromtext("abs(1)"),
        relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
    )
def test_parens():
    """`*` binds tighter than `+`: only the left-grouped parenthesization is equivalent."""
    def parse(src):
        return relay.fromtext(SEMVER + src)

    ungrouped = "1 * 1 + 1"
    assert alpha_equal(parse(ungrouped), parse("(1 * 1) + 1"))
    assert not alpha_equal(parse(ungrouped), parse("1 * (1 + 1)"))
def roundtrip(expr):
    """Assert that `expr` printed with `astext()` parses back to a graph-equal value."""
    assert_graph_equal(relay.fromtext(expr.astext()), expr)
def roundtrip(expr):
    """Assert that `str(expr)` parses back to an alpha-equal value."""
    printed = str(expr)
    reparsed = relay.fromtext(printed)
    assert_alpha_equal(reparsed, expr)
def test_builtin_types():
    """Every builtin type name in TYPES is accepted as a let-binding annotation."""
    annotated_let = "let %_ : {} = (); ()"
    for type_name in TYPES:
        # format() only the body; SEMVER is prepended unformatted.
        relay.fromtext(SEMVER + annotated_let.format(type_name))
def parses_as(code, expr):
    # type: (str, relay.Expr) -> bool
    """Return whether `code`, under the version header, parses alpha-equal to `expr`."""
    parsed = relay.fromtext("\n".join([SEMVER, code]))
    return alpha_equal(parsed, expr)
def parse_text(code):
    """Parse `code` under the version header, check it round-trips, and return it."""
    source = SEMVER + "\n" + code
    parsed = relay.fromtext(source)
    roundtrip(parsed)
    return parsed
def test_float_literal():
    """Float literals, incl. negative and scientific notation, parse to their value.

    Values that are not exactly representable are compared with `isclose`.
    """
    # (literal, expected value, compare approximately?)
    cases = [
        ("1.0", 1.0, False),
        ("1.56667", 1.56667, True),
        ("0.0", 0.0, False),
        ("-10.0", -10.0, False),
        # scientific notation
        ("1e-1", 1e-1, True),
        ("1e+1", 1e+1, False),
        ("1E-1", 1E-1, True),
        ("1E+1", 1E+1, False),
        ("1.0e-1", 1.0e-1, True),
        ("1.0e+1", 1.0e+1, False),
        ("1.0E-1", 1.0E-1, True),
        ("1.0E+1", 1.0E+1, False),
    ]
    for literal, expected, approximate in cases:
        value = get_scalar(relay.fromtext(SEMVER + literal))
        if approximate:
            assert isclose(value, expected)
        else:
            assert value == expected
def parse_text(code):
    """Parse `code` prefixed by the version header, verify the round-trip, return the result."""
    expr = relay.fromtext("\n".join((SEMVER, code)))
    roundtrip(expr)
    return expr
def test_let_op():
    """Attempt a let whose binder `x` lacks the `%` sigil.

    NOTE(review): `x` here would be an operator identifier, not a local var, so
    this is presumably expected to raise a parse error (via a decorator outside
    this view) -- confirm.
    """
    source = SEMVER + "let x = 1; ()"
    relay.fromtext(source)
def test_bool_literal():
    """`True` and `False` literals parse to the corresponding boolean scalars."""
    for literal, expected in (("True", True), ("False", False)):
        assert get_scalar(relay.fromtext(SEMVER + literal)) == expected
def test_op_assoc():
    """Mixed-precedence operator chains group as if fully parenthesized."""
    equivalent_pairs = [
        ("1 * 1 + 1 < 1 == 1", "(((1 * 1) + 1) < 1) == 1"),
        ("1 == 1 < 1 + 1 * 1", "1 == (1 < (1 + (1 * 1)))"),
    ]
    for bare, grouped in equivalent_pairs:
        assert alpha_equal(relay.fromtext(SEMVER + bare),
                           relay.fromtext(SEMVER + grouped))