def test_loop():
    """Check that ``dcpe`` leaves a call to a self-recursive function as a call.

    ``loop`` is defined as ``loop(x) = loop(x)`` (no base case), and the
    result of running ``dcpe`` on ``loop(1)`` must be alpha-equal to a plain
    ``Call`` of ``loop`` on ``const(1)``.
    """
    module = Module()
    tvar = TypeVar("t")
    param = Var("x", tvar)
    loop_fn = GlobalVar("loop")
    # The body calls itself with the same argument, so it cannot be unrolled.
    module[loop_fn] = Function([param], loop_fn(param), tvar, [tvar])
    result = dcpe(loop_fn(const(1)), mod=module)
    # Reference call: attrs None and a single None in the fourth slot
    # (presumably the type-argument list) - confirm against the Call API.
    assert alpha_equal(result, Call(loop_fn, [const(1)], None, [None]))
def test_loop():
    """Check that ``dcpe`` keeps a call to a self-recursive function intact.

    The reference expression is stored in the module as ``main`` and read
    back from there. It is then compared structurally against the body of the
    function that ``dcpe`` returns.
    """
    module = tvm.IRModule()
    tvar = TypeVar("t")
    param = Var("x", tvar)
    loop_fn = GlobalVar("loop")
    # loop(x) = loop(x): recursion with no base case.
    module[loop_fn] = Function([param], loop_fn(param), tvar, [tvar])
    # Round-trip the expected call through the module so the comparison uses
    # the stored form (presumably after the module's own processing).
    module["main"] = Function([], Call(loop_fn, [const(1)]))
    expected = module["main"].body
    wrapper = Function([], loop_fn(const(1)))
    result = dcpe(wrapper, mod=module)
    assert tvm.ir.structural_equal(result.body, expected)
def test_loop():
    """Check that ``dcpe`` keeps a call to a self-recursive function intact.

    The expected call is installed as the module's entry function and read
    back. It is then compared (alpha-equivalence) against the body of the
    function that ``dcpe`` returns.
    """
    module = Module()
    tvar = TypeVar("t")
    param = Var("x", tvar)
    loop_fn = GlobalVar("loop")
    # loop(x) = loop(x): recursion with no base case.
    module[loop_fn] = Function([param], loop_fn(param), tvar, [tvar])
    # Store the reference call under the entry function, then fetch the
    # stored body as the value to compare against.
    module[module.entry_func] = Function([], Call(loop_fn, [const(1)]))
    expected = module[module.entry_func].body
    wrapper = Function([], loop_fn(const(1)))
    result = dcpe(wrapper, mod=module)
    assert alpha_equal(result.body, expected)
def visit_call(self, call):
    """Rebuild ``call`` from its visited operator and visited arguments.

    The callee is visited first, then each argument in order. A new ``Call``
    is built from the results and carries over ``call.attrs`` and
    ``call.type_args``, so the rewritten call keeps its type arguments
    instead of silently dropping them.

    Args:
        call: The call expression being mutated.

    Returns:
        A new ``Call`` built from the visited operator and arguments.
    """
    new_op = self.visit(call.op)
    new_args = [self.visit(arg) for arg in call.args]
    # type_args is passed positionally (4th slot) rather than adding ``span``,
    # so this also works with Call constructors that lack a span parameter.
    return Call(new_op, new_args, call.attrs, call.type_args)