Beispiel #1
0
def test_threadpool_controller_as_decorator():
    """Nested ``controller.wrap`` decorators only apply inside the wrapped call.

    An outer function limited to 1 BLAS thread calls an inner function
    limited to 2 BLAS threads. Each function must observe its own limit, the
    outer limit must be back in force after the inner call returns, and the
    original thread configuration must be restored once the outer call ends.

    Skipped when any detected library reports fewer than 2 threads, or when
    no BLAS runtime is loaded.
    """
    controller = ThreadpoolController()
    original_info = controller.info()

    # A limit of 2 can only be observed if every library can use >= 2 threads.
    has_single_threaded_lib = any(info["num_threads"] < 2 for info in original_info)
    if has_single_threaded_lib:
        pytest.skip("Test requires at least 2 CPUs on host machine")
    if not controller.select(user_api="blas"):
        pytest.skip("Requires a blas runtime.")

    def assert_blas_threads(expected_num_threads):
        # Query through a brand-new controller rather than `controller` itself.
        current_blas = ThreadpoolController().select(user_api="blas")
        for blas_lib in current_blas.lib_controllers:
            assert blas_lib.num_threads == expected_num_threads

    @controller.wrap(limits=1, user_api="blas")
    def run_with_one_thread():
        assert_blas_threads(expected_num_threads=1)
        run_with_two_threads()
        # Leaving the inner scope must bring back the outer limit.
        assert_blas_threads(expected_num_threads=1)

    @controller.wrap(limits=2, user_api="blas")
    def run_with_two_threads():
        assert_blas_threads(expected_num_threads=2)

    run_with_one_thread()

    # Outside both decorated scopes the original configuration is back.
    restored_info = ThreadpoolController().info()
    assert restored_info == original_info
Beispiel #2
0
def test_threadpool_limits_by_prefix(prefix, limit):
    """Limits keyed by library prefix cap the thread count of matching libraries.

    Inside ``threadpool_limits(limits={prefix: limit})``, every library
    selected by ``prefix`` must report between 1 and ``limit`` threads.
    Libraries for which ``is_old_openblas`` is true are not checked. The
    original configuration must be back once the context exits.

    Skipped when no library matches ``prefix``.
    """
    controller = ThreadpoolController()
    original_info = controller.info()

    matching = controller.select(prefix=prefix)
    if not matching:
        pytest.skip(f"Requires {prefix} runtime")

    per_prefix_limits = {prefix: limit}
    with threadpool_limits(limits=per_prefix_limits):
        for lib in matching.lib_controllers:
            # NOTE(review): presumably old OpenBLAS cannot report its thread
            # count reliably, hence excluded -- confirm against is_old_openblas.
            if not is_old_openblas(lib):
                # Only an upper bound is imposed, so fewer threads is fine.
                current = lib.num_threads
                assert 0 < current <= limit
    assert ThreadpoolController().info() == original_info
Beispiel #3
0
def test_set_threadpool_limits_by_api(user_api, limit):
    """Limits applied per ``user_api`` cap the thread count of matching libraries.

    With ``user_api=None``, every library found by the controller is checked.
    Otherwise, only libraries selected by ``user_api`` are checked. Inside the
    ``threadpool_limits`` context, each checked library must report between
    1 and ``limit`` threads. Libraries for which ``is_old_openblas`` is true
    are not checked. The original configuration must be back once the
    context exits.

    Skipped when no library matches the requested API(s).
    """
    controller = ThreadpoolController()
    original_info = controller.info()

    matching = controller if user_api is None else controller.select(user_api=user_api)
    if not matching:
        user_apis = [user_api] if user_api is not None else _ALL_USER_APIS
        pytest.skip(f"Requires a library which api is in {user_apis}")

    with threadpool_limits(limits=limit, user_api=user_api):
        for lib in matching.lib_controllers:
            # NOTE(review): presumably old OpenBLAS cannot report its thread
            # count reliably, hence excluded -- confirm against is_old_openblas.
            if not is_old_openblas(lib):
                # Only an upper bound is imposed, so fewer threads is fine.
                current = lib.num_threads
                assert 0 < current <= limit

    assert ThreadpoolController().info() == original_info