コード例 #1
0
 async def connect(self,
                   address: str,
                   deserialize=True,
                   **connection_args) -> UCX:
     """Create a UCX endpoint to ``address`` and wrap it in a comm object.

     Parameters
     ----------
     address : str
         Peer location, parsed with ``parse_host_port``; ``self.prefix`` is
         prepended to it to form the comm's ``peer_addr``.
     deserialize : bool
         Forwarded unchanged to ``self.comm_class``.
     **connection_args
         Accepted for interface compatibility; not used by this method.

     Returns
     -------
     UCX
         A ``self.comm_class`` instance wrapping the new endpoint, with an
         empty ``local_addr``.

     Raises
     ------
     CommClosedError
         If ``ucp.create_endpoint`` fails with one of the UCX close / cancel /
         connection-level exceptions collected below.
     """
     logger.debug("UCXConnector.connect: %s", address)
     ip, port = parse_host_port(address)
     init_once()
     # Built after init_once() because ``ucp`` must be ready before its
     # exception classes can be looked up. Some classes may be absent
     # (getattr returns None), so they are filtered out. This gives the
     # except clause a flat tuple of exception classes, as the language
     # reference requires. Padding with ``()`` would rely on CPython's
     # nested-tuple matching.
     closed_errors = tuple(
         exc_type
         for exc_type in (
             ucp.exceptions.UCXCloseError,
             ucp.exceptions.UCXCanceled,
             getattr(ucp.exceptions, "UCXConnectionReset", None),
             getattr(ucp.exceptions, "UCXNotConnected", None),
             getattr(ucp.exceptions, "UCXUnreachable", None),
         )
         if exc_type is not None
     )
     try:
         ep = await ucp.create_endpoint(ip, port)
     except closed_errors as e:
         raise CommClosedError(
             "Connection closed before handshake completed") from e
     return self.comm_class(
         ep,
         local_addr="",
         peer_addr=self.prefix + address,
         deserialize=deserialize,
     )
コード例 #2
0
    async def connect(self, address, deserialize=True, **connection_args):
        """Open a TCP (optionally TLS) stream to ``address`` and return a comm.

        Parameters
        ----------
        address : str
            Peer location, split with ``parse_host_port``; ``self.prefix`` is
            prepended to form the comm's peer address.
        deserialize : bool
            Passed positionally to ``self.comm_class``.
        **connection_args
            Checked by ``_check_encryption`` and turned into the keyword
            arguments for ``self.client.connect`` by ``_get_connect_args``.

        Raises
        ------
        FatalCommClosedError
            On a TLS certificate mismatch or any other ``SSLError``.

        A ``StreamClosedError`` is passed to ``convert_stream_closed_error``.
        NOTE(review): that helper is presumably expected to raise; otherwise
        ``tcp_stream`` would be unbound below. Confirm.
        """
        self._check_encryption(address, connection_args)
        host, port_number = parse_host_port(address)
        connect_kwargs = self._get_connect_args(**connection_args)

        try:
            tcp_stream = await self.client.connect(
                host,
                port_number,
                max_buffer_size=MAX_BUFFER_SIZE,
                **connect_kwargs,
            )
            # In some cases tornado hands back an already-closed connection
            # that carries an error instead of raising StreamClosedError
            # (seen with tornado 5.x and openssl 1.1+), so raise it here.
            if tcp_stream.closed() and tcp_stream.error:
                raise StreamClosedError(tcp_stream.error)
        except StreamClosedError as exc:
            # The socket connect() call failed.
            convert_stream_closed_error(self, exc)
        except SSLCertVerificationError as exc:
            # Must precede the broader SSLError handler below.
            raise FatalCommClosedError(
                "TLS certificate does not match. Check your security settings. "
                "More info at https://distributed.dask.org/en/latest/tls.html"
            ) from exc
        except SSLError as exc:
            raise FatalCommClosedError() from exc

        local_address = self.prefix + get_stream_address(tcp_stream)
        peer_address = self.prefix + address
        return self.comm_class(tcp_stream, local_address, peer_address, deserialize)
コード例 #3
0
 def get_local_address_for(self, loc):
     """Return a local host address (without port) for reaching ``loc``.

     The host part of ``loc`` is resolved with ``ensure_ip``. An address
     containing ``":"`` is treated as IPv6 and handed to ``get_ipv6``, and
     anything else goes to ``get_ip``. The result is formatted with
     ``unparse_host_port`` and a ``None`` port. Presumably these helpers
     pick the local interface that routes to that host; confirm against
     their definitions.
     """
     remote_host, _ = parse_host_port(loc)
     remote_ip = ensure_ip(remote_host)
     pick_local_ip = get_ipv6 if ":" in remote_ip else get_ip
     return unparse_host_port(pick_local_ip(remote_ip), None)
コード例 #4
0
    async def connect(self, address, deserialize=True, **kwargs):
        """Open an asyncio connection to ``address`` and wrap it in a comm.

        The connection is made with ``loop.create_connection`` using
        ``DaskCommProtocol``. Extra keyword arguments are first translated by
        ``self._get_extra_kwargs``. Both the local and the peer address of the
        returned comm carry ``self.prefix``.
        """
        event_loop = asyncio.get_running_loop()
        host, port_number = parse_host_port(address)

        extra_kwargs = self._get_extra_kwargs(address, **kwargs)
        # Only the protocol object is needed; the transport is not used here.
        _transport, comm_protocol = await event_loop.create_connection(
            DaskCommProtocol, host, port_number, **extra_kwargs
        )
        return self.comm_class(
            comm_protocol,
            self.prefix + comm_protocol.local_addr,
            self.prefix + address,
            deserialize=deserialize,
        )
コード例 #5
0
ファイル: core.py プロジェクト: sjl070707/dask-xgboost
def _train(client, params, data, labels, **kwargs):
    """
    Asynchronous version of train

    This is a generator-style coroutine: it suspends with ``yield`` on each
    asynchronous client/scheduler call. It is presumably driven by a
    coroutine decorator or runner not visible here; confirm.

    Steps:

    1. Break ``data`` and ``labels`` into delayed chunks.
    2. Pair each data chunk with its label chunk and compute the pairs.
    3. Group the resulting keys by the worker that holds them.
    4. Start the XGBoost tracker on the scheduler.
    5. Submit ``train_part`` to every worker with that worker's parts.
    6. Return the first truthy result.

    See Also
    --------
    train
    """
    # Break apart Dask.array/dataframe into chunks/parts.
    data_chunks = data.to_delayed()
    label_chunks = labels.to_delayed()
    if isinstance(data_chunks, np.ndarray):
        # Data must be chunked along rows only (a single column of blocks).
        assert data_chunks.shape[1] == 1
        data_chunks = data_chunks.flatten().tolist()
    if isinstance(label_chunks, np.ndarray):
        assert label_chunks.ndim == 1 or label_chunks.shape[1] == 1
        label_chunks = label_chunks.flatten().tolist()

    # Wrap each (data, label) pair in one delayed object. This enforces
    # co-locality: both halves end up computed on the same worker.
    paired = [delayed(pair) for pair in zip(data_chunks, label_chunks)]
    computed_parts = client.compute(paired)  # starts computing in background
    yield _wait(computed_parts)

    # XGBoost-python cannot train iteratively, so find where every chunk
    # lives and map each chunk to a particular Dask worker.
    who_has = yield client.scheduler.who_has(
        keys=[part.key for part in computed_parts]
    )
    keys_by_worker = defaultdict(list)
    for key, holders in who_has.items():
        keys_by_worker[first(holders)].append(key)

    ncores = yield client.scheduler.ncores()  # number of cores per worker

    # Start the XGBoost tracker on the Dask scheduler.
    scheduler_host, _ = parse_host_port(client.scheduler.address)
    env = yield client._run_on_scheduler(
        start_tracker, scheduler_host.strip('/:'), len(keys_by_worker)
    )

    # Tell each worker to train on the chunks/parts it already holds.
    train_futures = []
    for worker, worker_keys in keys_by_worker.items():
        train_futures.append(
            client.submit(
                train_part,
                env,
                assoc(params, 'nthreads', ncores[worker]),
                worker_keys,
                workers=worker,
                **kwargs
            )
        )

    # Only one worker returns a (truthy) result; the rest return None.
    results = yield client._gather(train_futures)
    return [outcome for outcome in results if outcome][0]
コード例 #6
0
 def __init__(
     self,
     address,
     comm_handler,
     deserialize=True,
     allow_offload=True,
     default_host=None,
     default_port=0,
     **kwargs,
 ):
     """Parse ``address`` and record the listener's configuration.

     ``address`` is split with ``parse_host_port``; ``default_port`` is used
     when it has no port. The remaining options are stored as attributes.
     Extra ``kwargs`` are translated by ``self._get_extra_kwargs``.
     ``bound_address`` starts as ``None``.
     """
     host, port_number = parse_host_port(address, default_port)
     self.ip = host
     self.port = port_number
     self.default_host = default_host
     self.comm_handler = comm_handler
     self.deserialize = deserialize
     self.allow_offload = allow_offload
     self._extra_kwargs = self._get_extra_kwargs(address, **kwargs)
     # Not assigned by the constructor.
     self.bound_address = None
コード例 #7
0
 def __init__(
     self,
     address: str,
     comm_handler,
     deserialize=False,
     allow_offload=True,
     **connection_args,
 ):
     """Record the UCX listener's address and options.

     Parameters
     ----------
     address : str
         Listen address. ``"ucx://"`` is prepended when it does not already
         start with ``"ucx"``. Host and port are then extracted with
         ``parse_host_port``, defaulting the port to 0.
     comm_handler
         Stored unchanged on ``self.comm_handler``. It is not restricted to
         ``None``; presumably it is the per-connection callback (confirm
         against callers).
     deserialize : bool
         Stored on ``self.deserialize``.
     allow_offload : bool
         Stored on ``self.allow_offload``.
     **connection_args
         Stored unchanged on ``self.connection_args``.
     """
     if not address.startswith("ucx"):
         address = "ucx://" + address
     self.ip, self._input_port = parse_host_port(address, default_port=0)
     self.comm_handler = comm_handler
     self.deserialize = deserialize
     self.allow_offload = allow_offload
     # Not set by the constructor.
     self._ep = None  # type: ucp.Endpoint
     self.ucp_server = None
     self.connection_args = connection_args
コード例 #8
0
async def _create_listeners(session_state, nworkers, rank):
    """Record ``nworkers``/``rank`` in ``session_state`` and start a listener.

    The listener uses the same protocol and host as
    ``session_state["worker"].address``, with no explicit port. Each incoming
    endpoint is expected to first send its peer rank. The endpoint is then
    stored in ``session_state["eps"][peer_rank]``.

    Parameters
    ----------
    session_state : dict
        Mutable per-session state. It must already contain ``"loop"``,
        ``"worker"`` and ``"eps"``, and must not yet contain ``"nworkers"``
        or ``"rank"`` (checked with asserts).
    nworkers : int
        Stored as ``session_state["nworkers"]``.
    rank : int
        Stored as ``session_state["rank"]``.

    Returns
    -------
    The started listener's ``listen_address``. The listener itself is kept
    in ``session_state["lf"]``.
    """
    # Inside a coroutine, get_running_loop() is the recommended way to obtain
    # the current loop. It returns the same loop as get_event_loop() did here.
    assert session_state["loop"] is asyncio.get_running_loop()
    assert "nworkers" not in session_state
    session_state["nworkers"] = nworkers
    assert "rank" not in session_state
    session_state["rank"] = rank

    async def server_handler(ep):
        # The connecting peer announces its rank as the first message.
        peer_rank = await ep.read()
        session_state["eps"][peer_rank] = ep

    # We listen on the same protocol and address as the worker address.
    # Keep only the host part so the listener picks its own port.
    protocol, address = parse_address(session_state["worker"].address)
    address = parse_host_port(address)[0]
    address = unparse_address(protocol, address)

    session_state["lf"] = distributed.comm.listen(address, server_handler)
    await session_state["lf"].start()
    return session_state["lf"].listen_address
コード例 #9
0
    def __init__(
        self,
        address: str,
        handler: Callable,
        deserialize=True,
        allow_offload=False,
        **connection_args,
    ):
        """Normalise ``address`` and record the listener's configuration.

        ``self.prefix`` is prepended to ``address`` when missing. The result
        is then split into ``self.ip``/``self.port`` with a default port of 0.
        ``connection_args`` are stored as-is and also turned into
        ``self.server_args`` by ``self._get_server_args``.
        """
        full_address = (
            address
            if address.startswith(self.prefix)
            else f"{self.prefix}{address}"
        )
        self.ip, self.port = parse_host_port(full_address, default_port=0)
        self.handler = handler
        self.deserialize = deserialize
        self.allow_offload = allow_offload
        self.connection_args = connection_args
        # Not assigned by the constructor.
        self.bound_address = None
        self.new_comm_server = True
        self.server_args = self._get_server_args(**connection_args)
コード例 #10
0
 def __init__(
     self,
     address,
     comm_handler,
     deserialize=True,
     allow_offload=True,
     default_host=None,
     default_port=0,
     **connection_args,
 ):
     """Check encryption settings and record the TCP listener's configuration.

     ``self._check_encryption`` runs first, before any attribute is set.
     ``address`` is then split with ``parse_host_port``, using
     ``default_port`` when it has no port. ``self.server_args`` is derived
     from ``connection_args`` via ``self._get_server_args``. ``tcp_server``
     and ``bound_address`` start as ``None``.
     """
     self._check_encryption(address, connection_args)
     host, port_number = parse_host_port(address, default_port)
     self.ip = host
     self.port = port_number
     self.default_host = default_host
     self.comm_handler = comm_handler
     self.deserialize = deserialize
     self.allow_offload = allow_offload
     self.server_args = self._get_server_args(**connection_args)
     # Not created by the constructor.
     self.tcp_server = None
     self.bound_address = None
コード例 #11
0
 def resolve_address(self, loc):
     """Return ``loc`` with its host part replaced by ``ensure_ip(host)``.

     The port is kept as parsed, and the pair is re-joined with
     ``unparse_host_port``.
     """
     host, port = parse_host_port(loc)
     resolved_ip = ensure_ip(host)
     return unparse_host_port(resolved_ip, port)
コード例 #12
0
 def get_address_host_port(self, loc):
     """Split ``loc`` into host and port.

     Delegates to ``parse_host_port`` and returns its result unchanged.
     """
     host_and_port = parse_host_port(loc)
     return host_and_port
コード例 #13
0
 def parse_it(x):
     """Drop the scheme from address ``x`` and split the rest into host/port."""
     location = parse_address(x)[1]
     return parse_host_port(location)