def test_ast_build():
    """Check that the AST generator honours the at_each_domain callback.

    The test checks that:

    * ``set_at_each_domain`` attaches the callback only to the build it
      returns; calling ``node_from`` on the build it was invoked on does
      not trigger the callback;
    * an exception raised from the callback makes ``node_from`` raise;
    * installing a different callback on a build derived from one with a
      failing callback leaves the original build's callback in place.

    It finishes by running ``test_ast_build_unroll`` on the same schedule.
    """
    # NOTE(review): test_ast_build is defined more than once in this file;
    # only the last definition is bound to the name after import.
    schedule = construct_schedule_tree()

    # One-element list so the nested callback can update the counter.
    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    # NOTE(review): the expected count of 2 presumably matches the number
    # of statements in construct_schedule_tree(); confirm if it changes.
    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    assert count_ast[0] == 0
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    # fail_inc_count_ast reads do_fail at call time (closure), so
    # reassigning it further down switches the callback to succeeding.
    do_fail = True
    count_ast_fail = [0]

    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            raise Exception("fail")
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit still abort the run instead of counting as "caught".
        caught = True
    assert caught
    assert count_ast_fail[0] > 0

    # Replacing the callback on a derived build must not touch `build`.
    build_copy = build.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)
def test_ast_build():
    """Check that the AST generator honours the at_each_domain callback.

    The test checks that:

    * ``set_at_each_domain`` attaches the callback only to the build it
      returns; calling ``node_from`` on the build it was invoked on does
      not trigger the callback;
    * an exception raised from the callback makes ``node_from`` raise;
    * installing a different callback on a build derived from one with a
      failing callback leaves the original build's callback in place.

    It finishes by running ``test_ast_build_unroll`` on the same schedule.
    """
    # NOTE(review): test_ast_build is defined more than once in this file;
    # only the last definition is bound to the name after import.
    schedule = construct_schedule_tree()

    # One-element list so the nested callback can update the counter.
    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    # NOTE(review): the expected count of 2 presumably matches the number
    # of statements in construct_schedule_tree(); confirm if it changes.
    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    assert count_ast[0] == 0
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    # fail_inc_count_ast reads do_fail at call time (closure), so
    # reassigning it further down switches the callback to succeeding.
    do_fail = True
    count_ast_fail = [0]

    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            # Raise a real exception object: `raise "fail"` is itself a
            # TypeError in Python 3 and only "worked" by accident.
            raise Exception("fail")
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit still abort the run instead of counting as "caught".
        caught = True
    assert caught
    assert count_ast_fail[0] > 0

    # Replacing the callback on a derived build must not touch `build`.
    build_copy = build.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)
def __init__(self, ast_build=None):
    """Set up the wrapped ``isl.ast_build`` and wire in matching hooks.

    :param ast_build: build to start from; when falsy, a fresh
        ``isl.ast_build()`` is created.

    For every ``set_<name>`` attribute of the build, if this object has an
    attribute called ``<name>``, that setter is called with it and the
    result replaces ``self.ast_build``. Setters are tried in ``dir()``
    order, i.e. alphabetically.
    """
    super().__init__()
    self.ast_build = ast_build if ast_build else isl.ast_build()
    # The attribute list is taken once, from the initial build.
    setters = [a for a in dir(self.ast_build) if a.startswith('set_')]
    for setter in setters:
        option = setter[len('set_'):]
        if not hasattr(self, option):
            continue
        # self.ast_build is updated on every step so that hooks looked up
        # later always see the latest build.
        configure = getattr(self.ast_build, setter)
        self.ast_build = configure(getattr(self, option))
def test_ast_build_unroll(schedule):
    """Check at_each_domain invocation counts after marking bands for unrolling.

    Each band node reached by ``map_descendant_bottom_up`` gets member 0
    marked for unrolling (via ``member_set_ast_loop_unroll(0)``). An AST is
    then generated from the resulting schedule, and the number of
    at_each_domain callback invocations is checked.

    :param schedule: schedule to transform; it is not modified in place,
        a new schedule is obtained from the mapped root.
    """
    root = schedule.root()

    def mark_unroll(node):
        # isinstance instead of an exact type() comparison.
        if isinstance(node, isl.schedule_node_band):
            node = node.member_set_ast_loop_unroll(0)
        return node

    root = root.map_descendant_bottom_up(mark_unroll)
    schedule = root.schedule()

    # One-element list so the nested callback can update the counter.
    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    # NOTE(review): 30 presumably equals the number of statement instances
    # once the marked loops are unrolled; confirm against the schedule
    # passed in by the caller if that fixture changes.
    assert count_ast[0] == 30