def test_pow():
    """Differentiate `double` (x + x) applied through nat_iterate with count 3.

    The gradient function must have type Tensor -> (Tensor, (Tensor,)).
    Its forward output must equal 8 * input, and its gradient must be all 8s.
    """
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat_iterate = mod.get_global_var("nat_iterate")

    shape = (10, 10)
    dtype = "float32"
    tensor_ty = relay.TensorType(shape, dtype)

    arg = relay.var("x", tensor_ty)
    double = relay.Function([arg], arg + arg)

    inp = relay.var("i", tensor_ty)
    three = make_nat_expr(p, 3)
    mod["main"] = relay.Function([inp], nat_iterate(double, three)(inp))
    mod = transform.InferType()(mod)

    mod["main"] = gradient(mod["main"], mod=mod)
    back_func = transform.InferType()(mod)["main"]
    expected_ty = relay.FuncType(
        [tensor_ty], relay.TupleType([tensor_ty, relay.TupleType([tensor_ty])])
    )
    assert back_func.checked_type == expected_ty

    i_nd = rand(dtype, *shape)
    executor = create_executor(mod=mod)
    forward, (grad_i,) = executor.evaluate(back_func)(i_nd)
    tvm.testing.assert_allclose(forward.numpy(), 8 * i_nd.numpy())
    tvm.testing.assert_allclose(grad_i.numpy(), 8 * np.ones_like(grad_i.numpy()))
def test_update():
    """Writing values 0..9 with `update` makes `nth` read them back in order."""
    expected = list(range(10))

    # create zero initialized list
    lst = nil()
    for _ in expected:
        lst = cons(make_nat_expr(0), lst)

    # set value
    for idx, val in enumerate(expected):
        lst = update(lst, relay.const(idx), make_nat_expr(val))

    got = [count(intrp.evaluate(nth(lst, relay.const(idx)))) for idx in range(len(expected))]
    assert got == expected
def test_foldr1():
    """foldr1 has the expected polymorphic type, and folding `add` over [1, 2, 3] gives 6."""
    a = relay.TypeVar("a")
    foldr1_ty = relay.FuncType([relay.FuncType([a, a], a), l(a)], a, [a])
    assert mod[p.foldr1].checked_type == foldr1_ty

    x = relay.Var("x")
    y = relay.Var("y")
    adder = relay.Function([x, y], add(x, y))

    # Build [1, 2, 3] back to front.
    nums = nil()
    for val in (3, 2, 1):
        nums = cons(make_nat_expr(val), nums)

    res = intrp.evaluate(foldr1(adder, nums))
    assert count(res) == 6
def test_unfoldl():
    """unfoldl has the expected polymorphic type and unfolds a countdown from 3 into [1, 2, 3]."""
    a = relay.TypeVar("a")
    b = relay.TypeVar("b")
    expected_type = relay.FuncType(
        [relay.FuncType([a], optional(relay.TupleType([a, b]))), a], l(b), [a, b]
    )
    # expected_type used to be constructed and then never checked; verify it
    # against the type of the registered definition.
    assert mod[unfoldl].checked_type == expected_type

    x = relay.Var("x", nat())
    n = relay.Var("n", nat())
    # s(n) -> some((n, x)): per expected_type the tuple is (next seed, emitted value).
    # z -> none(): stop unfolding.
    count_down = relay.Function(
        [x],
        relay.Match(
            x,
            [
                relay.Clause(
                    relay.PatternConstructor(s, [relay.PatternVar(n)]), some(relay.Tuple([n, x]))
                ),
                relay.Clause(relay.PatternConstructor(z, []), none()),
            ],
        ),
    )

    res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3)))
    unfolded = to_list(res)

    assert len(unfolded) == 3
    assert count(unfolded[0]) == 1
    assert count(unfolded[1]) == 2
    assert count(unfolded[2]) == 3
def test_concat():
    """concat has the expected type, and concat([1, 2], [3, 4]) evaluates to [1, 2, 3, 4]."""
    a = relay.TypeVar("a")
    concat_ty = relay.FuncType([rlist(a), rlist(a)], rlist(a), [a])
    assert prelude.mod[concat].checked_type == concat_ty

    front = cons(make_nat_expr(prelude, 1), cons(make_nat_expr(prelude, 2), nil()))
    back = cons(make_nat_expr(prelude, 3), cons(make_nat_expr(prelude, 4), nil()))
    catted = to_list(intrp.evaluate(concat(front, back)))

    assert len(catted) == 4
    for pos, want in enumerate([1, 2, 3, 4]):
        assert count(catted[pos]) == want
def test_foldr():
    """foldr has the expected type, and folding `cons` from nil rebuilds [1, 2, 3] unchanged."""
    a = relay.TypeVar("a")
    b = relay.TypeVar("b")
    foldr_ty = relay.FuncType([relay.FuncType([a, b], b), b, l(a)], b, [a, b])
    assert mod[foldr].checked_type == foldr_ty

    x = relay.Var("x")
    y = relay.Var("y")
    rebuild = relay.Function([x, y], cons(x, y))
    source = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
    same = to_list(intrp.evaluate(foldr(rebuild, nil(), source)))

    assert len(same) == 3
    for pos, want in enumerate([1, 2, 3]):
        assert count(same[pos]) == want
def test_rev():
    """rev has the expected type, and rev([1, 2, 3]) evaluates to [3, 2, 1]."""
    a = relay.TypeVar("a")
    rev_ty = relay.FuncType([rlist(a)], rlist(a), [a])
    assert prelude.mod[rev].checked_type == rev_ty

    forwards = cons(
        make_nat_expr(prelude, 1),
        cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())),
    )
    # Named `backwards` rather than `reversed` so the builtin is not shadowed.
    backwards = to_list(intrp.evaluate(rev(forwards)))

    assert len(backwards) == 3
    for pos, want in enumerate([3, 2, 1]):
        assert count(backwards[pos]) == want
def test_optional_matching():
    """Folding over [some(3), none, some(1)] with a match on some/none keeps only [3, 1]."""
    x = relay.Var('x')
    y = relay.Var('y')
    v = relay.Var('v')

    keep_some = relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y))
    skip_none = relay.Clause(relay.PatternConstructor(none), y)
    condense = relay.Function([x, y], relay.Match(x, [keep_some, skip_none]))

    options = cons(some(make_nat_expr(3)), cons(none(), cons(some(make_nat_expr(1)), nil())))
    reduced = to_list(intrp.evaluate(foldr(condense, nil(), options)))

    assert len(reduced) == 2
    assert count(reduced[0]) == 3
    assert count(reduced[1]) == 1
def test_foldl():
    """foldl has the expected type; prepending each element twice turns [1, 2, 3] into [3, 3, 2, 2, 1, 1]."""
    a = relay.TypeVar("a")
    b = relay.TypeVar("b")
    foldl_ty = relay.FuncType([relay.FuncType([a, b], a), a, l(b)], a, [a, b])
    assert mod[foldl].checked_type == foldl_ty

    x = relay.Var("x")
    y = relay.Var("y")
    # The accumulator comes first: (acc, elem) -> elem :: elem :: acc.
    rev_dup = relay.Function([y, x], cons(x, cons(x, y)))
    source = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
    # Named `doubled` rather than `reversed` so the builtin is not shadowed.
    doubled = to_list(intrp.evaluate(foldl(rev_dup, nil(), source)))

    assert len(doubled) == 6
    for pos, want in enumerate([3, 3, 2, 2, 1, 1]):
        assert count(doubled[pos]) == want
def test_zip():
    """zip has the expected type, pairs elements by position, and truncates to the shorter input."""
    a = relay.TypeVar("a")
    b = relay.TypeVar("b")
    expected_type = relay.FuncType([rlist(a), rlist(b)], rlist(relay.TupleType([a, b])), [a, b])
    assert prelude.mod[zip].checked_type == expected_type

    def _check(result, expected):
        # `expected` holds one (nat value of first component,
        # length of second component) pair per zipped element.
        pairs = to_list(result)
        assert len(pairs) == len(expected)
        for idx, (want_nat, want_len) in enumerate(expected):
            assert count(pairs[idx][0]) == want_nat
            assert len(to_list(pairs[idx][1])) == want_len

    l1 = cons(
        make_nat_expr(prelude, 1),
        cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())),
    )
    # Inner lists of length 0, 1 and 2.
    l2 = cons(
        nil(), cons(cons(nil(), nil()), cons(cons(nil(), cons(nil(), nil())), nil())))
    _check(intrp.evaluate(zip(l1, l2)), [(1, 0), (2, 1), (3, 2)])

    # test truncation
    l3 = cons(make_nat_expr(prelude, 4), cons(make_nat_expr(prelude, 5), nil()))
    _check(intrp.evaluate(zip(l3, l2)), [(4, 0), (5, 1)])

    l4 = cons(nil(), nil())
    _check(intrp.evaluate(zip(l3, l4)), [(4, 0)])
def test_abs_diff():
    """Applying `dcpe` to a recursive `diff(7, 3)` over nats yields the nat literal 4."""
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat, z, s = p.mod.get_type("nat")

    x = Var("x", nat())
    y = Var("y", nat())
    xp = Var("x'", nat())
    yp = Var("y'", nat())
    diff = GlobalVar("diff")

    # Match on y, reached only in the branch where x = s(x').
    match_y = Match(
        y,
        [
            Clause(PatternConstructor(z, []), x),
            Clause(PatternConstructor(s, [PatternVar(yp)]), diff(yp, xp)),
        ],
    )
    mod[diff] = Function(
        [x, y],
        Match(
            x,
            [
                Clause(PatternConstructor(z, []), y),
                Clause(PatternConstructor(s, [PatternVar(xp)]), match_y),
            ],
        ),
    )

    call = Function([], diff(make_nat_expr(p, 7), make_nat_expr(p, 3)))
    res = dcpe(call, mod=mod)
    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4))
def test_abs_diff():
    """Applying `dcpe` to a recursive `diff(7, 3)` over nats yields the nat literal 4 (Module API)."""
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()

    x = Var("x", nat)
    y = Var("y", nat)
    xp = Var("x'", nat)
    yp = Var("y'", nat)
    diff = GlobalVar("diff")

    # Match on y, reached only in the branch where x = s(x').
    match_y = Match(y, [
        Clause(PatternConstructor(p.z, []), x),
        Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp)),
    ])
    mod[diff] = Function([x, y], Match(x, [
        Clause(PatternConstructor(p.z, []), y),
        Clause(PatternConstructor(p.s, [PatternVar(xp)]), match_y),
    ]))

    call = Function([], diff(make_nat_expr(p, 7), make_nat_expr(p, 3)))
    res = dcpe(call, mod=mod)
    assert alpha_equal(res.body, make_nat_expr(p, 4))
def test_foldr1():
    """foldr1 has the expected polymorphic type, and folding `add` over [1, 2, 3] gives 6."""
    a = relay.TypeVar("a")
    foldr1_ty = relay.FuncType([relay.FuncType([a, a], a), rlist(a)], a, [a])
    assert prelude.mod[foldr1].checked_type == foldr1_ty

    x = relay.Var("x")
    y = relay.Var("y")
    adder = relay.Function([x, y], add(x, y))
    nums = cons(
        make_nat_expr(prelude, 1),
        cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())),
    )

    total = eval(foldr1(adder, nums))
    assert count(total) == 6
def test_hd_tl():
    """Walking the list 0..9 with repeated hd/tl yields its elements in order."""
    expected = list(range(10))

    lst = nil()
    for val in expected[::-1]:
        lst = cons(make_nat_expr(val), lst)

    got = []
    for _ in expected:
        got.append(count(intrp.evaluate(hd(lst))))
        lst = tl(lst)
    assert got == expected
def test_filter():
    """filter has the expected type and keeps only the nats matching s(s(_)) from [1, 1, 3, 1, 5, 1]."""
    a = relay.TypeVar("a")
    expected_type = relay.FuncType(
        [relay.FuncType([a], relay.scalar_type("bool")), l(a)], l(a), [a])
    assert mod[filter].checked_type == expected_type

    x = relay.Var("x", nat())
    # Two nested successor constructors, i.e. values of at least 2.
    two_or_more = relay.PatternConstructor(
        s, [relay.PatternConstructor(s, [relay.PatternWildcard()])])
    greater_than_one = relay.Function(
        [x],
        relay.Match(x, [
            relay.Clause(two_or_more, relay.const(True)),
            relay.Clause(relay.PatternWildcard(), relay.const(False)),
        ]),
    )

    # Build [1, 1, 3, 1, 5, 1] back to front.
    values = nil()
    for val in (1, 5, 1, 3, 1, 1):
        values = cons(make_nat_expr(val), values)

    filtered = to_list(intrp.evaluate(filter(greater_than_one, values)))
    assert len(filtered) == 2
    assert count(filtered[0]) == 3
    assert count(filtered[1]) == 5
def test_nested_matches():
    """A hand-written, nested-match `flatten` turns [[1, 2, 3], [4, 5, 6]] into [1, ..., 6]."""
    a = relay.TypeVar("a")
    # TODO(@jroesch): inference should be able to handle this one
    x = relay.Var("x", type_annotation=rlist(rlist(a)))
    y = relay.Var("y")
    w = relay.Var("w")
    h = relay.Var("h")
    t = relay.Var("t")
    flatten = relay.GlobalVar("flatten")

    # flatten could be written using a fold, but this way has nested matches
    inner_clauses = [
        relay.Clause(relay.PatternConstructor(nil), flatten(w)),
        relay.Clause(
            relay.PatternConstructor(cons, [relay.PatternVar(h), relay.PatternVar(t)]),
            cons(h, flatten(cons(t, w))),
        ),
    ]
    inner_match = relay.Match(y, inner_clauses)

    outer_clauses = [
        relay.Clause(relay.PatternConstructor(nil), nil()),
        relay.Clause(
            relay.PatternConstructor(cons, [relay.PatternVar(y), relay.PatternVar(w)]),
            inner_match,
        ),
    ]
    prelude.mod[flatten] = relay.Function([x], relay.Match(x, outer_clauses), rlist(a), [a])

    def _three_nats(first, second, third):
        # Build the relay list [first, second, third] of nat literals.
        return cons(
            make_nat_expr(prelude, first),
            cons(make_nat_expr(prelude, second), cons(make_nat_expr(prelude, third), nil())),
        )

    nested = cons(_three_nats(1, 2, 3), cons(_three_nats(4, 5, 6), nil()))
    flat = to_list(intrp.evaluate(flatten(nested)))

    assert len(flat) == 6
    for pos, item in enumerate(flat):
        assert count(item) == pos + 1
def test_recursion():
    """A to_cps/un_cps round trip keeps `double` iterated with count 3 equal to 8 * input (Module API)."""
    mod = relay.Module()
    p = Prelude(mod)
    add_nat_definitions(p)

    shape = (10, 10)
    dtype = 'float32'
    tensor_ty = relay.TensorType(shape, dtype)

    arg = relay.var("x", tensor_ty)
    double = relay.Function([arg], arg + arg)
    inp = relay.var("i", tensor_ty)

    entry = mod.entry_func
    mod[entry] = relay.Function([inp], p.nat_iterate(double, make_nat_expr(p, 3))(inp))
    mod[entry] = to_cps(mod[entry], mod=mod)
    mod[entry] = un_cps(mod[entry])

    executor = create_executor(mod=mod)
    i_nd = rand(dtype, *shape)
    forward = executor.evaluate(entry)(i_nd)
    tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
def test_pow():
    """Differentiate `double` (x + x) iterated with count 3; expect output 8 * input and gradient 8 (Module API)."""
    mod = relay.Module()
    p = Prelude(mod)
    add_nat_definitions(p)

    shape = (10, 10)
    dtype = 'float32'
    tensor_ty = relay.TensorType(shape, dtype)

    arg = relay.var("x", tensor_ty)
    double = relay.Function([arg], arg + arg)
    inp = relay.var("i", tensor_ty)
    func = relay.Function([inp], p.nat_iterate(double, make_nat_expr(p, 3))(inp))

    grad_func = gradient(func, mod=mod)
    back_func = relay.ir_pass.infer_type(grad_func, mod=mod)
    expected_ty = relay.FuncType(
        [tensor_ty], relay.TupleType([tensor_ty, relay.TupleType([tensor_ty])]))
    assert back_func.checked_type == expected_ty

    i_nd = rand(dtype, *shape)
    executor = create_executor(mod=mod)
    forward, (grad_i,) = executor.evaluate(back_func)(i_nd)
    tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
    tvm.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
def test_recursion():
    """A to_cps/un_cps round trip keeps `double` iterated with count 3 equal to 8 * input."""
    mod = tvm.IRModule()
    p = Prelude(mod)
    p.mod.import_from_std("nat.rly")
    nat_iterate = p.mod.get_global_var("nat_iterate")

    shape = (10, 10)
    dtype = "float32"
    tensor_ty = relay.TensorType(shape, dtype)

    arg = relay.var("x", tensor_ty)
    double = relay.Function([arg], arg + arg)
    inp = relay.var("i", tensor_ty)
    mod["main"] = relay.Function([inp], nat_iterate(double, make_nat_expr(p, 3))(inp))

    # Re-run type inference after each rewrite of main.
    mod = relay.transform.InferType()(mod)
    mod["main"] = to_cps(mod["main"], mod=mod)
    mod = relay.transform.InferType()(mod)
    mod["main"] = un_cps(mod["main"])

    i_nd = rand(dtype, *shape)
    run_main = create_executor(mod=mod).evaluate()
    forward = run_main(i_nd)
    tvm.testing.assert_allclose(forward.numpy(), 8 * i_nd.numpy())
def test_nested_matches():
    """A hand-written, nested-match `flatten` turns [[1, 2, 3], [4, 5, 6]] into [1, ..., 6]."""
    a = relay.TypeVar("a")
    x = relay.Var("x")
    y = relay.Var("y")
    w = relay.Var("w")
    h = relay.Var("h")
    t = relay.Var("t")
    flatten = relay.GlobalVar("flatten")

    # flatten could be written using a fold, but this way has nested matches
    inner_clauses = [
        relay.Clause(relay.PatternConstructor(nil), flatten(w)),
        relay.Clause(
            relay.PatternConstructor(cons, [relay.PatternVar(h), relay.PatternVar(t)]),
            cons(h, flatten(cons(t, w))),
        ),
    ]
    inner_match = relay.Match(y, inner_clauses)

    outer_clauses = [
        relay.Clause(relay.PatternConstructor(nil), nil()),
        relay.Clause(
            relay.PatternConstructor(cons, [relay.PatternVar(y), relay.PatternVar(w)]),
            inner_match,
        ),
    ]
    mod[flatten] = relay.Function([x], relay.Match(x, outer_clauses), l(a), [a])

    def _three_nats(first, second, third):
        # Build the relay list [first, second, third] of nat literals.
        return cons(make_nat_expr(first), cons(make_nat_expr(second), cons(make_nat_expr(third), nil())))

    nested = cons(_three_nats(1, 2, 3), cons(_three_nats(4, 5, 6), nil()))
    flat = to_list(intrp.evaluate(flatten(nested)))

    assert len(flat) == 6
    for pos, item in enumerate(flat):
        assert count(item) == pos + 1
def test_iterate():
    """Applying iterate(double, 2) to the nat 3 evaluates to 12."""
    double_twice = iterate(double, relay.const(2))
    expr = relay.Call(double_twice, [make_nat_expr(3)])
    thunk_call = relay.Function([], expr)()
    assert count(intrp.evaluate(thunk_call)) == 12