def test_get_preload_options(enable_tcp, enable_infiniband_netdev, enable_nvlink):
    """Check the preload options built for the UCX protocol.

    Calls ``get_preload_options`` with ``protocol="ucx"`` and the given
    transport switches. It then verifies the preload module and that each
    enabled switch shows up as the corresponding ``preload_argv`` flag.

    ``enable_infiniband_netdev`` is unpacked as a
    ``(enable_infiniband, net_devices)`` pair. ``net_devices`` may be a
    callable taking a device index, or a string. (The parametrization feeding
    these arguments is not visible here — confirm against the decorators or
    fixtures.)
    """
    # Match test_get_preload_options_default: UCX-protocol tests skip when
    # the ``ucp`` package is unavailable.
    pytest.importorskip("ucp")

    enable_infiniband, net_devices = enable_infiniband_netdev

    opts = get_preload_options(
        protocol="ucx",
        create_cuda_context=True,
        enable_tcp_over_ucx=enable_tcp,
        enable_infiniband=enable_infiniband,
        enable_nvlink=enable_nvlink,
        ucx_net_devices=net_devices,
        cuda_device_index=5,
    )

    assert "preload" in opts
    assert opts["preload"] == ["dask_cuda.initialize"]
    assert "preload_argv" in opts
    assert "--create-cuda-context" in opts["preload_argv"]

    if enable_tcp:
        assert "--enable-tcp-over-ucx" in opts["preload_argv"]

    if enable_infiniband:
        assert "--enable-infiniband" in opts["preload_argv"]

        if callable(net_devices):
            # The index passed here mirrors cuda_device_index=5 above.
            dev = net_devices(5)
            assert "--net-devices=" + dev in opts["preload_argv"]
        elif isinstance(net_devices, str) and net_devices != "":
            assert "--net-devices=" + net_devices in opts["preload_argv"]

    if enable_nvlink:
        assert "--enable-nvlink" in opts["preload_argv"]
def test_get_preload_options_default():
    """Check the UCX preload options when only a CUDA context is requested.

    Apart from ``protocol`` and ``create_cuda_context``, every keyword is left
    at its default. The resulting ``preload_argv`` must contain exactly the
    CUDA-context flag.
    """
    pytest.importorskip("ucp")

    options = get_preload_options(protocol="ucx", create_cuda_context=True)

    # Each key must be present and hold exactly the expected list.
    expectations = (
        ("preload", ["dask_cuda.initialize"]),
        ("preload_argv", ["--create-cuda-context"]),
    )
    for key, expected in expectations:
        assert key in options
        assert options[key] == expected