def test_no_pool_error():
    """Expect greedy_by_size to raise when the only candidate pool is too small.

    Three 10-byte buffers all target a single pool whose size hint is 10
    bytes, and their conflict lists form a cycle (bi_a -> bi_b -> bi_c ->
    bi_a). The allocator is expected to exceed the pool and raise a
    ``tvm.TVMError`` whose message matches the pattern below.
    """
    target = Target("c")
    tiny_workspace_pool = usmp_utils.PoolInfo(
        pool_name="tiny_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
        size_hint_bytes=10,
    )
    bi_a = usmp_utils.BufferInfo(
        name_hint="bi_a", size_bytes=10, pool_candidates=[tiny_workspace_pool]
    )
    bi_b = usmp_utils.BufferInfo(
        name_hint="bi_b", size_bytes=10, pool_candidates=[tiny_workspace_pool]
    )
    bi_c = usmp_utils.BufferInfo(
        name_hint="bi_c", size_bytes=10, pool_candidates=[tiny_workspace_pool]
    )
    bi_a.set_conflicts([bi_b])
    bi_b.set_conflicts([bi_c])
    bi_c.set_conflicts([bi_a])
    buffer_info_arr = [bi_a, bi_b, bi_c]
    # No placeholders are needed, so this is a plain literal rather than an f-string.
    fusmp_algo = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
    with pytest.raises(
        tvm.TVMError,
        match="TVM USMP Error: the space available in the provided pools exceeded",
    ):
        # The call is expected to raise, so there is no result to keep.
        fusmp_algo(buffer_info_arr, 0)
def _test():
    """Allocate three equal-size, mutually conflicting buffers and check offsets.

    The selected algorithm is expected to place bi_a at byte offset 20,
    bi_b at 10 and bi_c at 0.

    NOTE(review): ``algorithm`` is neither a parameter nor a local here; it
    must be supplied by an enclosing scope (presumably a parametrized test)
    that is not visible in this chunk -- confirm.
    """
    target = Target("c")
    global_workspace_pool = usmp_utils.PoolInfo(
        pool_name="global_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )
    buffers = [
        usmp_utils.BufferInfo(
            name_hint=name, size_bytes=10, pool_candidates=[global_workspace_pool]
        )
        for name in ("bi_a", "bi_b", "bi_c")
    ]
    # Each buffer conflicts with the other two, listed in rotation order:
    # bi_a -> [bi_b, bi_c], bi_b -> [bi_c, bi_a], bi_c -> [bi_a, bi_b].
    count = len(buffers)
    for pos, buf in enumerate(buffers):
        buf.set_conflicts([buffers[(pos + step) % count] for step in range(1, count)])

    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
    buffer_pool_allocations = fusmp_algo(buffers, 0)

    # Expected offsets are in reverse name order.
    for buf, expected_offset in zip(buffers, (20, 10, 0)):
        assert buffer_pool_allocations[buf].byte_offset == expected_offset
def test_linear(algorithm, workspace_size):
    """
    The test case here represents BufferInfo objects that could get generated
    for a linear sequence such as:

        (Op A)
          |
         bi_a
          |
        (Op B)
          |
         bi_b
          .
          .
        (Op F)
          |
         bi_f

    Each buffer conflicts only with its immediate predecessor and successor
    in the chain. The allocation produced by ``algorithm`` is passed to
    ``_check_max_workspace_size`` together with the expected ``workspace_size``.
    """
    target = Target("c")
    global_workspace_pool = usmp_utils.PoolInfo(
        pool_name="global_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )
    chain_spec = [
        ("bi_a", 10),
        ("bi_b", 20),
        ("bi_c", 100),
        ("bi_d", 40),
        ("bi_e", 50),
        ("bi_f", 50),
    ]
    chain = [
        usmp_utils.BufferInfo(
            name_hint=name, size_bytes=size, pool_candidates=[global_workspace_pool]
        )
        for name, size in chain_spec
    ]
    # Creating conflicts for a linear graph: buffer i overlaps buffer i - 1
    # (listed first) and buffer i + 1; the chain ends have a single neighbour.
    for idx, buf in enumerate(chain):
        neighbours = chain[max(idx - 1, 0):idx] + chain[idx + 1:idx + 2]
        buf.set_conflicts(neighbours)

    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
    buffer_pool_allocations = fusmp_algo(chain, 0)
    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
def test_fanout(algorithm, workspace_size):
    """
    The test case here represents BufferInfo objects that could get generated
    for a fanout topology such as:

        (Op A)
          |
         bi_a ---------
          |            |
        (Op B)       (Op C)
          |            |
         bi_b         bi_c
          |            |
        (Op D)       (Op E)
          |            |
         bi_d         bi_e
          |            |
        (Op F) --------
          |
         bi_f
          |
        (Op G)
          |
         bi_g

    The allocation produced by ``algorithm`` is passed to
    ``_check_max_workspace_size`` together with the expected ``workspace_size``.
    """
    target = Target("c")
    global_workspace_pool = WorkspacePoolInfo(
        "global_workspace",
        targets=[target],
    )
    sizes = {
        "bi_a": 10,
        "bi_b": 20,
        "bi_c": 100,
        "bi_d": 40,
        "bi_e": 50,
        "bi_f": 60,
        "bi_g": 70,
    }
    # Dict insertion order fixes both construction order and the order of
    # the array handed to the allocator (bi_a .. bi_g).
    buffers = {
        name: usmp_utils.BufferInfo(
            name_hint=name, size_bytes=size, pool_candidates=[global_workspace_pool]
        )
        for name, size in sizes.items()
    }
    # Creating conflicts for a fanout graph (per-buffer list order is kept as-is).
    conflict_names = {
        "bi_a": ["bi_b", "bi_c"],
        "bi_b": ["bi_a", "bi_c", "bi_e"],
        "bi_c": ["bi_e", "bi_a", "bi_b", "bi_d"],
        "bi_d": ["bi_b", "bi_f", "bi_c", "bi_e"],
        "bi_e": ["bi_c", "bi_f", "bi_b", "bi_d"],
        # NOTE(review): bi_f lists itself and omits bi_g, although bi_g lists
        # bi_f, so the conflict relation is asymmetric. This looks like a typo
        # for bi_g, but correcting it likely changes the expected
        # workspace_size values supplied by the (not visible) parametrization
        # -- TODO confirm and update both together.
        "bi_f": ["bi_d", "bi_e", "bi_f"],
        "bi_g": ["bi_f"],
    }
    for name, others in conflict_names.items():
        buffers[name].set_conflicts([buffers[other] for other in others])

    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
    buffer_pool_allocations = fusmp_algo(list(buffers.values()), 0)
    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)