def test_bounded(random_len=150, pools=None):
    """Tests two pools, one is bounded and one is not limited.

    Builds ``random_len`` buffers of random size, gives each one a randomly
    chosen conflicting partner (conflicts are recorded on both sides), runs
    the ``tir.usmp.algo.hill_climb`` allocator and verifies the result with
    ``_verify_all_conflicts``.

    Parameters
    ----------
    random_len : int
        Number of buffers to generate.
    pools : list, optional
        Pool candidates handed to every buffer. When omitted, a "default"
        pool created with ``65535`` as third argument and a "slow" pool
        without it are used (presumably the size limit -- confirm against
        the ``PoolInfo`` signature).
    """
    if pools is None:
        # Built per call rather than as a shared mutable default argument.
        pools = [PoolInfo("default", {}, 65535), PoolInfo("slow", {})]

    # Fixed seed keeps the generated sizes and conflicts reproducible.
    random.seed(0)
    mem_range = [
        BufferInfo(str(i), random.randrange(1, 65535), pools) for i in range(random_len)
    ]

    for mr in mem_range:
        # A buffer that already conflicts with every other buffer has no
        # valid partner left; without this guard the rejection-sampling
        # loop below would never terminate (e.g. for random_len <= 2).
        if len(mr.conflicts) >= len(mem_range) - 1:
            continue

        # Re-draw until the partner is neither the buffer itself nor one it
        # already conflicts with.
        pr = random.choice(mem_range)
        while pr in (*mr.conflicts, mr):
            pr = random.choice(mem_range)

        # Conflicts are symmetric, so record the pair on both buffers.
        mr.set_conflicts([*mr.conflicts, pr])
        pr.set_conflicts([*pr.conflicts, mr])

    fusmp_algo = tvm.get_global_func("tir.usmp.algo.hill_climb")
    # NOTE(review): the meaning of the second argument (0) is not visible
    # here -- confirm against the hill_climb registration.
    result_map = fusmp_algo(mem_range, 0)
    _verify_all_conflicts(result_map)
def run_intervals(intervals, tolerance=0):
    """Run the USMP allocators on buffers described by ``intervals``.

    Each entry of ``intervals`` is a ``(start, stop, size)`` triple. One
    buffer is created per entry in a single "default" workspace pool, and two
    buffers are marked as conflicting when their ranges overlap (the ranges
    are measured as ``stop - start + 1``, i.e. with both ends counted).

    Both ``tir.usmp.algo.hill_climb`` and ``tir.usmp.algo.greedy_by_size``
    are executed; every allocation is passed to ``_verify_all_conflicts``
    and compared with the value from ``find_maximum_from_intervals`` through
    ``_check_max_workspace_size`` using ``tolerance``.

    Returns
    -------
    dict
        Maps each algorithm name to the first value returned by
        ``_check_max_workspace_size``; when that value is falsy the
        accompanying message is printed.
    """
    expected_mem = find_maximum_from_intervals(intervals)
    pools = [WorkspacePoolInfo("default", [])]

    # One buffer per interval, named after its position in the list.
    buffers = [
        BufferInfo(str(index), size, pools) for index, (_, _, size) in enumerate(intervals)
    ]

    # Two intervals overlap when the span covering both of them is shorter
    # than the sum of their individual lengths.
    for this, (this_start, this_stop, _) in enumerate(intervals):
        this_len = this_stop - this_start + 1
        overlapping = {
            buffers[other]
            for other, (other_start, other_stop, _) in enumerate(intervals)
            if this != other
            and (
                max(this_stop, other_stop) - min(this_start, other_start) + 1
                < this_len + (other_stop - other_start + 1)
            )
        }
        # Sort by name so the conflict list does not depend on set ordering.
        buffers[this].set_conflicts(sorted(overlapping, key=lambda buf: buf.name_hint))

    algorithms = [
        ("tir.usmp.algo.hill_climb", (expected_mem,)),
        ("tir.usmp.algo.greedy_by_size", (expected_mem,)),
    ]

    result = {}
    for alg, params in algorithms:
        fusmp_algo = tvm.get_global_func(alg)
        print("\n", "started", alg)
        buffer_info_arr = fusmp_algo(buffers, *params)
        print()

        _verify_all_conflicts(buffer_info_arr)
        passed, msg = _check_max_workspace_size(
            buffer_info_arr, pools[0], expected_mem, tolerance
        )
        result[alg] = passed
        if not passed:
            print(alg, msg)

    return result