Example #1
0
def test_register_backend_entrypoint():
    # Code adapted from pandas backend entry point testing
    # https://github.com/pandas-dev/pandas/blob/2470690b9f0826a8feb426927694fa3500c3e8d2/pandas/tests/plotting/test_backend.py#L50-L76
    #
    # Injects a fake module exposing ``UDPBackend`` (a factory returning 1),
    # registers it as the "udp" entry point of the "distributed.comm.backends"
    # group on the installed distribution, and checks that get_backend("udp")
    # returns what the factory produced.
    # NOTE(review): the fake module and entry point are not removed afterwards.

    distribution = pkg_resources.get_distribution("distributed")
    is_installed = distribution.module_path in distributed.__file__
    if not is_installed:
        # Running from a non-installed distributed; this test is invalid there.
        pytest.skip("Testing a non-installed distributed")

    fake_module = types.ModuleType("dask_udp")
    fake_module.UDPBackend = lambda: 1
    sys.modules[fake_module.__name__] = fake_module

    group_name = "distributed.comm.backends"
    entry_map = pkg_resources.get_entry_map("distributed")
    # Create the group on first use, then register "udp" inside it.
    group_entries = entry_map.setdefault(group_name, {})
    group_entries["udp"] = pkg_resources.EntryPoint(
        "udp", fake_module.__name__, attrs=["UDPBackend"], dist=distribution
    )

    # require=False: unit-test environments may install dirty or dev versions
    # that conflict with backend entry points demanding exact, stable
    # versions. That conflict should not fail this test.
    backend = get_backend("udp", require=False)
    assert backend == 1
Example #2
0
def resolve_address(addr: str) -> str:
    """
    Return *addr* with its symbolic parts replaced by concrete locations.

    The backend registered for the address scheme does the actual
    resolution; in practice this can mean a hostname becomes an IP address.

    >>> resolve_address('tcp://localhost:8786')
    'tcp://127.0.0.1:8786'
    """
    scheme, location = parse_address(addr)
    resolved_location = registry.get_backend(scheme).resolve_address(location)
    return unparse_address(scheme, resolved_location)
Example #3
0
async def test_inproc_handshakes_concurrently():
    # Opens two raw connector-level connections back to back, closing neither
    # until both exist. Presumably this checks that the listener does not
    # serialize pending handshakes (inferred from the test name).
    async def _handler():
        pass

    async with listen("inproc://", _handler) as listener:
        scheme, location = parse_address(listener.listen_address)
        connector = get_backend(scheme).get_connector()

        first = await connector.connect(location)
        second = await connector.connect(location)
        for comm in (first, second):
            await comm.close()
Example #4
0
def get_address_host(addr: str) -> str:
    """
    Return the hostname or IP address of the machine *addr* lives on.

    Unlike get_address_host_port(), this is expected to succeed for every
    well-formed address.

    >>> get_address_host('tcp://1.2.3.4:80')
    '1.2.3.4'
    """
    scheme, location = parse_address(addr)
    return registry.get_backend(scheme).get_address_host(location)
Example #5
0
async def test_inproc_continues_listening_after_handshake_error():
    async def _handler():
        pass

    async with listen("inproc://", _handler) as listener:
        scheme, location = parse_address(listener.listen_address)
        connector = get_backend(scheme).get_connector()

        # Connect and immediately close, twice. The second connect succeeding
        # presumably shows the listener keeps accepting after a comm went
        # away before finishing its handshake (inferred from the test name).
        for _ in range(2):
            comm = await connector.connect(location)
            await comm.close()
Example #6
0
def get_local_address_for(addr: str) -> str:
    """
    Return a local listening address from which *addr* can be reached.

    For example, for an external TCP address this yields a local TCP address
    routable to it; the concrete result depends on the host's network setup.

    >>> get_local_address_for('tcp://8.8.8.8:1234')
    'tcp://192.168.1.68'
    >>> get_local_address_for('tcp://127.0.0.1:1234')
    'tcp://127.0.0.1'
    """
    scheme, location = parse_address(addr)
    local_location = registry.get_backend(scheme).get_local_address_for(location)
    return unparse_address(scheme, local_location)
Example #7
0
def get_address_host_port(addr: str, strict: bool = False) -> tuple[str, int]:
    """
    Get a (host, port) tuple out of the given address.

    Parameters
    ----------
    addr : str
        The address to split.
    strict : bool
        Forwarded to ``parse_address``; see that function for its meaning.

    Raises
    ------
    ValueError
        If the address scheme doesn't allow extracting the requested
        information.

    >>> get_address_host_port('tcp://1.2.3.4:80')
    ('1.2.3.4', 80)
    >>> get_address_host_port('tcp://[::1]:80')
    ('::1', 80)
    """
    scheme, loc = parse_address(addr, strict=strict)
    backend = registry.get_backend(scheme)
    try:
        return backend.get_address_host_port(loc)
    except NotImplementedError:
        # The backend signals "unsupported" via NotImplementedError. Translate
        # it into the documented ValueError and suppress the implicit chain,
        # which would otherwise print a misleading "During handling of the
        # above exception, another exception occurred" traceback.
        raise ValueError(
            f"don't know how to extract host and port for address {addr!r}"
        ) from None
Example #8
0
def listen(addr, handle_comm, deserialize=True, **kwargs):
    """
    Build (but do not start) a listener for *addr*.

    Once its ``start()`` method is called, the listener listens on *addr*
    (a URI such as ``tcp://0.0.0.0``) and calls *handle_comm* with a
    ``Comm`` object for each incoming connection. *handle_comm* may be a
    regular function or a coroutine.

    If *addr* cannot be parsed strictly (presumably because it has no
    scheme), ``tls://`` is prepended when a truthy ``ssl_context`` keyword
    is supplied and ``tcp://`` otherwise, and parsing is retried.
    """
    try:
        scheme, loc = parse_address(addr, strict=True)
    except ValueError:
        default_prefix = "tls://" if kwargs.get("ssl_context") else "tcp://"
        scheme, loc = parse_address(default_prefix + addr, strict=True)

    return registry.get_backend(scheme).get_listener(
        loc, handle_comm, deserialize, **kwargs
    )
Example #9
0
def test_register_backend_entrypoint():
    # Code adapted from pandas backend entry point testing
    # https://github.com/pandas-dev/pandas/blob/2470690b9f0826a8feb426927694fa3500c3e8d2/pandas/tests/plotting/test_backend.py#L50-L76
    #
    # Injects a fake module exposing ``UDPBackend`` (a factory returning 1),
    # registers it as the "udp" entry point of the "distributed.comm.backends"
    # group on the installed distribution, and checks that get_backend("udp")
    # returns what the factory produced.
    # NOTE(review): the fake module and entry point are not removed afterwards.

    distribution = pkg_resources.get_distribution("distributed")
    is_installed = distribution.module_path in distributed.__file__
    if not is_installed:
        # Running from a non-installed distributed; this test is invalid there.
        pytest.skip("Testing a non-installed distributed")

    fake_module = types.ModuleType("dask_udp")
    fake_module.UDPBackend = lambda: 1
    sys.modules[fake_module.__name__] = fake_module

    group_name = "distributed.comm.backends"
    entry_map = pkg_resources.get_entry_map("distributed")
    # Create the group on first use, then register "udp" inside it.
    group_entries = entry_map.setdefault(group_name, {})
    group_entries["udp"] = pkg_resources.EntryPoint(
        "udp", fake_module.__name__, attrs=["UDPBackend"], dist=distribution
    )

    backend = get_backend("udp")
    assert backend == 1
def test_registered():
    # The "ucx" scheme must be known to the registry and resolve to a
    # UCXBackend instance.
    assert "ucx" in backends
    assert isinstance(get_backend("ucx"), ucx.UCXBackend)
Example #11
0
def test_registered():
    # The "ws" scheme must be known to the registry and resolve to a
    # WSBackend instance.
    assert "ws" in backends
    assert isinstance(get_backend("ws"), ws.WSBackend)
Example #12
0
def _get_backend_on_path(path):
    """Append *path* to ``sys.path`` (left in place) and return the "udp" backend."""
    decoded_path = os.fsdecode(path)
    sys.path.append(decoded_path)
    return get_backend("udp")
Example #13
0
async def connect(
    addr, timeout=None, deserialize=True, handshake_overrides=None, **connection_args
):
    """
    Connect to the given address (a URI such as ``tcp://127.0.0.1:1234``)
    and return a ``Comm`` object once the handshake has completed.  If a
    connection attempt fails, it is retried with jittered exponential
    backoff until *timeout* expires.

    Parameters
    ----------
    addr : str
        Address to connect to, including the scheme.
    timeout : optional
        Overall time budget shared by the connect attempts and the
        handshake. ``None`` reads ``distributed.comm.timeouts.connect`` from
        the dask config; the value is parsed with ``parse_timedelta``, with
        bare numbers interpreted as seconds.
    deserialize : bool
        Forwarded to the backend connector's ``connect``.
    handshake_overrides : dict or None
        Entries merged over ``comm.handshake_info()`` before it is sent to
        the peer.
    **connection_args
        Extra keyword arguments forwarded to the connector's ``connect``.

    Raises
    ------
    FatalCommClosedError
        Re-raised immediately from a connect attempt, without retrying.
    OSError
        If no connection is established before the timeout, or if the
        handshake fails for any reason. Both messages say "Timed out", even
        when the handshake failure was not a timeout.
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, default="seconds")

    scheme, loc = parse_address(addr)
    backend = registry.get_backend(scheme)
    connector = backend.get_connector()
    comm = None

    start = time()

    # Remaining budget in seconds, clamped at 0.
    def time_left():
        deadline = start + timeout
        return max(0, deadline - time())

    backoff_base = 0.01
    attempt = 0

    # Prefer several short attempts over one long attempt. This primarily
    # protects against DNS race conditions (gh3104, gh4176, gh4167).
    intermediate_cap = timeout / 5
    active_exception = None
    while time_left() > 0:
        try:
            comm = await asyncio.wait_for(
                connector.connect(loc, deserialize=deserialize, **connection_args),
                timeout=min(intermediate_cap, time_left()),
            )
            break
        except FatalCommClosedError:
            raise
        # Note: CommClosed inherits from OSError
        except (asyncio.TimeoutError, OSError) as exc:
            active_exception = exc

            # As described above, the intermediate timeout is used to
            # distribute the initial, bulk connect attempts evenly. Once we
            # are retrying with jitter, we no longer need to worry about
            # overloading DNS servers, so later attempts may use the full
            # remaining budget.
            intermediate_cap = timeout
            # FullJitter see https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

            upper_cap = min(time_left(), backoff_base * (2**attempt))
            backoff = random.uniform(0, upper_cap)
            attempt += 1
            logger.debug(
                "Could not connect to %s, waiting for %s before retrying", loc, backoff
            )
            await asyncio.sleep(backoff)
    else:
        # while/else: reached only when the loop ends without `break`, i.e.
        # the time budget ran out (or was 0 to begin with).
        raise OSError(
            f"Timed out trying to connect to {addr} after {timeout} s"
        ) from active_exception

    local_info = {
        **comm.handshake_info(),
        **(handshake_overrides or {}),
    }
    try:
        # This would be better, but connections leak if worker is closed quickly
        # write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
        # NOTE: the handshake shares the remaining budget with the connect
        # phase; time_left() may already be 0 here.
        handshake = await asyncio.wait_for(comm.read(), time_left())
        await asyncio.wait_for(comm.write(local_info), time_left())
    except Exception as exc:
        # Best-effort close; the original failure is what gets reported.
        with suppress(Exception):
            await comm.close()
        raise OSError(
            f"Timed out during handshake while connecting to {addr} after {timeout} s"
        ) from exc

    # Record both sides of the handshake plus the actual endpoint addresses.
    comm.remote_info = handshake
    comm.remote_info["address"] = comm._peer_addr
    comm.local_info = local_info
    comm.local_info["address"] = comm._local_addr

    comm.handshake_options = comm.handshake_configuration(
        comm.local_info, comm.remote_info
    )
    return comm