# Example #1
# 0
def test_controller_info_actualized():
    """Verify that ``num_threads`` tracks the live state of the threadpools.

    The controller is built before a global limit of 1 thread is applied. While
    that limit is active, every library controller it holds must report 1
    thread. After the limit is lifted, ``info()`` must match the snapshot taken
    at the start.
    """
    controller = ThreadpoolController()
    snapshot = controller.info()

    with threadpool_limits(limits=1):
        # The controller predates the limit, so this checks that num_threads
        # reflects the current state rather than the state at construction.
        thread_counts = (lib.num_threads for lib in controller.lib_controllers)
        assert all(count == 1 for count in thread_counts)

    assert snapshot == controller.info()
# Example #2
# 0
def test_threadpool_controller_as_decorator():
    """Check that ``ThreadpoolController.wrap`` decorators nest and stay scoped.

    The outer function runs with BLAS limited to 1 thread and calls the inner
    function, which is wrapped with a BLAS limit of 2. The limit of 1 must be in
    force again once the inner call returns. The original state must be
    restored after the outer call finishes.
    """
    controller = ThreadpoolController()
    original_info = controller.info()

    # Limits of 1 and 2 can only be told apart if every library starts with at
    # least 2 threads.
    if any(entry["num_threads"] < 2 for entry in original_info):
        pytest.skip("Test requires at least 2 CPUs on host machine")
    if not controller.select(user_api="blas"):
        pytest.skip("Requires a blas runtime.")

    def check_blas_num_threads(expected_num_threads):
        # A new controller is created for every check.
        fresh_blas = ThreadpoolController().select(user_api="blas")
        counts = (lib.num_threads for lib in fresh_blas.lib_controllers)
        assert all(count == expected_num_threads for count in counts)

    @controller.wrap(limits=1, user_api="blas")
    def run_with_one_thread():
        check_blas_num_threads(expected_num_threads=1)
        run_with_two_threads()  # nested wrap: BLAS limit of 2 inside this call
        check_blas_num_threads(expected_num_threads=1)

    @controller.wrap(limits=2, user_api="blas")
    def run_with_two_threads():
        check_blas_num_threads(expected_num_threads=2)

    run_with_one_thread()

    # Leaving the outermost decorated call must restore the initial state.
    assert ThreadpoolController().info() == original_info
# Example #3
# 0
def test_threadpool_controller_info():
    """Check the dicts returned by the ``info`` methods.

    Both ``threadpool_info()`` and ``ThreadpoolController.info()`` must equal the
    list of per-library ``info()`` dicts. Each dict must contain every key
    expected for the private api.
    """
    controller = ThreadpoolController()
    libs = controller.lib_controllers

    assert threadpool_info() == [lib.info() for lib in libs]
    assert controller.info() == [lib.info() for lib in libs]

    required_keys = (
        "user_api",
        "internal_api",
        "prefix",
        "filepath",
        "version",
        "num_threads",
    )
    for info_dict in controller.info():
        for key in required_keys:
            assert key in info_dict

        # "threading_layer" is only required for these internal apis.
        if info_dict["internal_api"] in ("mkl", "blis", "openblas"):
            assert "threading_layer" in info_dict
# Example #4
# 0
def test_threadpool_limits_by_prefix(prefix, limit):
    """Check that ``threadpool_limits`` can cap the thread count per prefix.

    ``prefix`` and ``limit`` are presumably supplied by pytest parametrization,
    which is not visible here. Libraries for which ``is_old_openblas`` is true
    are not checked.
    """
    controller = ThreadpoolController()
    original_info = controller.info()

    matching = controller.select(prefix=prefix)
    if not matching:
        pytest.skip(f"Requires {prefix} runtime")

    with threadpool_limits(limits={prefix: limit}):
        for lib in matching.lib_controllers:
            if not is_old_openblas(lib):
                # Only an upper bound is imposed; fewer threads is acceptable.
                assert 0 < lib.num_threads <= limit
    assert ThreadpoolController().info() == original_info
# Example #5
# 0
def test_set_threadpool_limits_by_api(user_api, limit):
    """Check that ``threadpool_limits`` can cap the thread count per user_api.

    When ``user_api`` is None, the test expects every library known to the
    controller to be limited. Libraries for which ``is_old_openblas`` is true
    are not checked.
    """
    controller = ThreadpoolController()
    original_info = controller.info()

    targeted = (
        controller if user_api is None else controller.select(user_api=user_api)
    )
    if not targeted:
        wanted_apis = [user_api] if user_api is not None else _ALL_USER_APIS
        pytest.skip(f"Requires a library which api is in {wanted_apis}")

    with threadpool_limits(limits=limit, user_api=user_api):
        for lib in targeted.lib_controllers:
            if not is_old_openblas(lib):
                # Only an upper bound is imposed; fewer threads is acceptable.
                assert 0 < lib.num_threads <= limit

    assert ThreadpoolController().info() == original_info
# Example #6
# 0
def test_nested_limits():
    """Check that nested ``controller.limit`` contexts restore limits on exit.

    An inner limit of 2 applies inside an outer limit of 1. Leaving the inner
    context brings back the limit of 1. Leaving the outer context restores the
    original configuration.
    """
    controller = ThreadpoolController()
    original_info = controller.info()

    # Limits of 1 and 2 can only be told apart if every library starts with at
    # least 2 threads.
    if any(entry["num_threads"] < 2 for entry in original_info):
        pytest.skip("Test requires at least 2 CPUs on host machine")

    def check_num_threads(expected_num_threads):
        current_libs = ThreadpoolController().lib_controllers
        assert all(lib.num_threads == expected_num_threads for lib in current_libs)

    with controller.limit(limits=1):
        check_num_threads(expected_num_threads=1)
        with controller.limit(limits=2):
            check_num_threads(expected_num_threads=2)
        # Back in the outer context: the limit of 1 must apply again.
        check_num_threads(expected_num_threads=1)

    assert ThreadpoolController().info() == original_info