def test_controller_info_actualized():
    # A controller created *before* threadpool_limits is entered must report
    # the new limit through its lib_controllers' num_threads attribute, and
    # its info() must match the original snapshot once the context exits.
    controller = ThreadpoolController()
    info_before = controller.info()

    with threadpool_limits(limits=1):
        for lib_ctl in controller.lib_controllers:
            assert lib_ctl.num_threads == 1

    assert controller.info() == info_before
def test_threadpool_controller_as_decorator():
    # Decorators obtained from ThreadpoolController.wrap must be nestable, and
    # each limit must only be in effect while its decorated function runs.
    controller = ThreadpoolController()
    info_before = controller.info()

    if any(lib_info["num_threads"] < 2 for lib_info in info_before):
        pytest.skip("Test requires at least 2 CPUs on host machine")
    if not controller.select(user_api="blas"):
        pytest.skip("Requires a blas runtime.")

    def assert_blas_threads(expected):
        # Inspect the BLAS libraries through a freshly created controller.
        fresh_blas = ThreadpoolController().select(user_api="blas")
        for lib_ctl in fresh_blas.lib_controllers:
            assert lib_ctl.num_threads == expected

    @controller.wrap(limits=1, user_api="blas")
    def limited_to_one():
        assert_blas_threads(1)
        limited_to_two()
        # Returning from the inner decorated call must bring back this
        # function's own limit.
        assert_blas_threads(1)

    @controller.wrap(limits=2, user_api="blas")
    def limited_to_two():
        assert_blas_threads(2)

    limited_to_one()

    assert ThreadpoolController().info() == info_before
def test_threadpool_controller_info():
    # threadpool_info(), ThreadpoolController.info() and the per-library
    # info() dicts must all agree, and every dict must carry the keys
    # expected for the private api.
    controller = ThreadpoolController()

    def per_library_info():
        return [lib_ctl.info() for lib_ctl in controller.lib_controllers]

    assert threadpool_info() == per_library_info()
    assert controller.info() == per_library_info()

    required_keys = (
        "user_api",
        "internal_api",
        "prefix",
        "filepath",
        "version",
        "num_threads",
    )
    for info_dict in controller.info():
        for key in required_keys:
            assert key in info_dict
        # threading_layer is only required for these internal apis.
        if info_dict["internal_api"] in ("mkl", "blis", "openblas"):
            assert "threading_layer" in info_dict
def test_threadpool_limits_by_prefix(prefix, limit):
    # Passing {prefix: limit} to threadpool_limits must cap the thread count
    # of every library selected by that prefix, and the original state must
    # be restored on exit.
    # NOTE(review): `prefix` and `limit` are presumably supplied by
    # pytest.mark.parametrize decorators outside this view — confirm.
    controller = ThreadpoolController()
    info_before = controller.info()

    selected = controller.select(prefix=prefix)
    if not selected:
        pytest.skip(f"Requires {prefix} runtime")

    with threadpool_limits(limits={prefix: limit}):
        # Libraries flagged by is_old_openblas are excluded from the check.
        checked = (
            lib_ctl
            for lib_ctl in selected.lib_controllers
            if not is_old_openblas(lib_ctl)
        )
        for lib_ctl in checked:
            # The limit is only an upper bound on the number of threads.
            assert 0 < lib_ctl.num_threads <= limit

    assert ThreadpoolController().info() == info_before
def test_set_threadpool_limits_by_api(user_api, limit):
    # threadpool_limits(limits=..., user_api=...) must cap the thread count of
    # every library exposing that user_api (all libraries when user_api is
    # None), and the original state must be restored on exit.
    # NOTE(review): `user_api` and `limit` are presumably supplied by
    # pytest.mark.parametrize decorators outside this view — confirm.
    controller = ThreadpoolController()
    info_before = controller.info()

    selected = (
        controller if user_api is None else controller.select(user_api=user_api)
    )
    if not selected:
        wanted_apis = [user_api] if user_api is not None else _ALL_USER_APIS
        pytest.skip(f"Requires a library which api is in {wanted_apis}")

    with threadpool_limits(limits=limit, user_api=user_api):
        # Libraries flagged by is_old_openblas are excluded from the check.
        checked = (
            lib_ctl
            for lib_ctl in selected.lib_controllers
            if not is_old_openblas(lib_ctl)
        )
        for lib_ctl in checked:
            # The limit is only an upper bound on the number of threads.
            assert 0 < lib_ctl.num_threads <= limit

    assert ThreadpoolController().info() == info_before
def test_nested_limits():
    # Nested ThreadpoolController.limit context managers must each restore the
    # limits that were active when they were entered, ending with the
    # original state once the outermost one exits.
    controller = ThreadpoolController()
    info_before = controller.info()

    if any(lib_info["num_threads"] < 2 for lib_info in info_before):
        pytest.skip("Test requires at least 2 CPUs on host machine")

    def assert_all_threads(expected):
        # Inspect every library through a freshly created controller.
        for lib_ctl in ThreadpoolController().lib_controllers:
            assert lib_ctl.num_threads == expected

    with controller.limit(limits=1):
        assert_all_threads(1)

        with controller.limit(limits=2):
            assert_all_threads(2)

        # Leaving the inner block must restore the outer limit, not the
        # original state.
        assert_all_threads(1)

    assert ThreadpoolController().info() == info_before