def test_threadpool_controller_as_decorator():
    # Nested use of the ``wrap`` decorator: the limit requested by a decorated
    # function must hold only while that function runs, and the limit of the
    # enclosing decorated function must be back in place once it returns.
    controller = ThreadpoolController()
    original_info = controller.info()

    # The inner function asks for 2 threads, so every detected library has to
    # report at least 2 threads for the check to be meaningful.
    for lib_info in original_info:
        if lib_info["num_threads"] < 2:
            pytest.skip("Test requires at least 2 CPUs on host machine")

    blas_selection = controller.select(user_api="blas")
    if not blas_selection:
        pytest.skip("Requires a blas runtime.")

    def assert_blas_threads(expected):
        # Build a fresh controller on every call so the thread counts are read
        # at the time of the check.
        current_blas = ThreadpoolController().select(user_api="blas")
        for blas_lib in current_blas.lib_controllers:
            assert blas_lib.num_threads == expected

    @controller.wrap(limits=1, user_api="blas")
    def limited_to_one():
        assert_blas_threads(1)
        limited_to_two()
        # Back in the outer scope: the outer limit must apply again.
        assert_blas_threads(1)

    @controller.wrap(limits=2, user_api="blas")
    def limited_to_two():
        assert_blas_threads(2)

    limited_to_one()

    # Outside of any decorated call the state must match the initial snapshot.
    assert ThreadpoolController().info() == original_info
def test_threadpool_limits_by_prefix(prefix, limit):
    # Limits given as a ``{prefix: limit}`` mapping must cap the number of
    # threads of every library selected by that prefix, and the initial state
    # must be back once the context manager exits.
    root_controller = ThreadpoolController()
    info_before = root_controller.info()

    matching = root_controller.select(prefix=prefix)
    if not matching:
        pytest.skip(f"Requires {prefix} runtime")

    per_prefix_limits = {prefix: limit}
    with threadpool_limits(limits=per_prefix_limits):
        for lib in matching.lib_controllers:
            # Libraries flagged by is_old_openblas are left unchecked.
            if not is_old_openblas(lib):
                # threadpool_limits only guarantees an upper bound, so any
                # positive count up to ``limit`` is accepted.
                assert 0 < lib.num_threads <= limit

    assert ThreadpoolController().info() == info_before
def test_set_threadpool_limits_by_api(user_api, limit):
    # Limits given together with a ``user_api`` must cap the number of threads
    # of every library exposing that API (or of every library when ``user_api``
    # is None), and the initial state must be back once the context exits.
    root_controller = ThreadpoolController()
    info_before = root_controller.info()

    # ``None`` means "no filtering": check the unfiltered controller itself.
    matching = (
        root_controller
        if user_api is None
        else root_controller.select(user_api=user_api)
    )

    if not matching:
        if user_api is None:
            user_apis = _ALL_USER_APIS
        else:
            user_apis = [user_api]
        pytest.skip(f"Requires a library which api is in {user_apis}")

    with threadpool_limits(limits=limit, user_api=user_api):
        for lib in matching.lib_controllers:
            # Libraries flagged by is_old_openblas are left unchecked.
            if not is_old_openblas(lib):
                # threadpool_limits only guarantees an upper bound, so any
                # positive count up to ``limit`` is accepted.
                assert 0 < lib.num_threads <= limit

    assert ThreadpoolController().info() == info_before