Ejemplo n.º 1
0
    def connect(
        self,
        conn_str: str,
        job_config: JobConfig = None,
        secure: bool = False,
        metadata: List[Tuple[str, str]] = None,
        connection_retries: int = 3,
        namespace: str = None,
        *,
        ignore_version: bool = False,
        _credentials: Optional[grpc.ChannelCredentials] = None,
        ray_init_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Open a Ray Client connection to the server at ``conn_str``.

        Args:
            conn_str: Server address, in the form "[host]:port".
            job_config: Job config passed to the server on init. When
                ``namespace`` is given, the namespace is set on this object
                (a fresh ``JobConfig`` is created if none was supplied).
            secure: Whether to use a TLS-secured gRPC channel.
            metadata: gRPC metadata to send on connect.
            connection_retries: Number of connection attempts to make.
            namespace: Optional Ray namespace to set on the job config.
            ignore_version: Whether to ignore Python or Ray version
                mismatches. This should only be used for debugging purposes.
            _credentials: Optional gRPC channel credentials handed to the
                client worker.
            ray_init_kwargs: Extra init options forwarded to the server. Non-None
                "logging_level" / "logging_format" entries also override the
                local logger configuration.

        Returns:
            Dictionary of connection info, e.g., {"num_clients": 1}.
            Returns None, without reconnecting, when a client worker already
            exists and ``self._connected_with_init`` is set.

        Raises:
            Exception: if a client worker already exists and
                ``self._connected_with_init`` is not set.
            Any exception from creating the worker, initializing the server,
            checking versions or registering serializers is re-raised after
            ``self.disconnect()`` has been called.
        """
        # Imported lazily (at connect time) to avoid circular imports.
        from ray.util.client.worker import Worker

        existing_worker = self.client_worker
        if existing_worker is not None and self._connected_with_init:
            return
        if existing_worker is not None:
            raise Exception("ray.init() called, but ray client is already connected")

        if not self._inside_client_test:
            # An explicit client connect must put the process into client
            # mode if it is not already there.
            _explicitly_enable_client_mode()

        if namespace is not None:
            if not job_config:
                job_config = JobConfig()
            job_config.set_ray_namespace(namespace)

        # Caller-supplied logging settings win only when present and not None.
        init_opts = ray_init_kwargs if ray_init_kwargs is not None else {}
        requested_level = init_opts.get("logging_level")
        requested_format = init_opts.get("logging_format")
        setup_logger(
            ray_constants.LOGGER_LEVEL if requested_level is None else requested_level,
            ray_constants.LOGGER_FORMAT if requested_format is None else requested_format,
        )

        try:
            worker = Worker(
                conn_str,
                secure=secure,
                _credentials=_credentials,
                metadata=metadata,
                connection_retries=connection_retries,
            )
            self.client_worker = worker
            self.api.worker = worker
            worker._server_init(job_config, ray_init_kwargs)
            info = worker.connection_info()
            self._check_versions(info, ignore_version)
            self._register_serializers()
        except Exception:
            # Tear down any partially-established connection before
            # propagating the failure.
            self.disconnect()
            raise
        return info
Ejemplo n.º 2
0
def main():
    """Run the Ray Client server process until it is told to stop.

    Parses command-line options. In ``--mode proxy`` (the default), it starts
    ``serve_proxier``; otherwise it starts ``serve`` with a Ray connect
    handler. It then loops once per second, writing a health-check record
    (``{"time": ...}``) under the key ``"ray_client_server"`` in the internal
    KV store.

    In ``specific-server`` mode the process shuts itself down once
    ``TIMEOUT_FOR_SPECIFIC_SERVER_S`` consecutive checks have seen no
    connected clients.

    The server is stopped on every exit path: Ctrl-C, the idle timeout, or an
    unexpected error (which is re-raised after stopping).
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host",
                        type=str,
                        default="0.0.0.0",
                        help="Host IP to bind to")
    parser.add_argument("-p",
                        "--port",
                        type=int,
                        default=10001,
                        help="Port to bind to")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["proxy", "legacy", "specific-server"],
        default="proxy",
    )
    parser.add_argument("--address",
                        required=False,
                        type=str,
                        help="Address to use to connect to Ray")
    parser.add_argument(
        "--redis-password",
        required=False,
        type=str,
        help="Password for connecting to Redis",
    )
    parser.add_argument(
        "--metrics-agent-port",
        required=False,
        type=int,
        default=0,
        help="The port to use for connecting to the runtime_env agent.",
    )
    # Unknown arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()
    setup_logger(ray_constants.LOGGER_LEVEL, ray_constants.LOGGER_FORMAT)

    # Only used by the non-proxy modes below.
    ray_connect_handler = create_ray_handler(args.address, args.redis_password)

    hostport = f"{args.host}:{args.port}"
    logger.info(f"Starting Ray Client server on {hostport}")
    if args.mode == "proxy":
        server = serve_proxier(
            hostport,
            args.address,
            redis_password=args.redis_password,
            runtime_env_agent_port=args.metrics_agent_port,
        )
    else:
        server = serve(hostport, ray_connect_handler)

    try:
        idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
        while True:
            health_report = {
                "time": time.time(),
            }

            try:
                if not ray.experimental.internal_kv._internal_kv_initialized():
                    gcs_client = try_create_gcs_client(args.address,
                                                       args.redis_password)
                    ray.experimental.internal_kv._initialize_internal_kv(
                        gcs_client)
                ray.experimental.internal_kv._internal_kv_put(
                    "ray_client_server",
                    json.dumps(health_report),
                    namespace=ray_constants.KV_NAMESPACE_HEALTHCHECK,
                )
            except Exception:
                # Best effort: a failed health-check write is logged (with
                # traceback) but must not take the server down.
                logger.exception(f"[{args.mode}] Failed to put health check "
                                 f"on {args.address}")

            time.sleep(1)
            if args.mode == "specific-server":
                if server.data_servicer.num_clients > 0:
                    idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
                else:
                    idle_checks_remaining -= 1
                    # `<=` rather than `==`: the countdown can never step past
                    # zero and spin forever. Checking only on the idle branch
                    # means connected clients never trigger a shutdown.
                    if idle_checks_remaining <= 0:
                        break
                if (idle_checks_remaining % 5 == 0 and idle_checks_remaining !=
                        TIMEOUT_FOR_SPECIFIC_SERVER_S):
                    logger.info(
                        f"{idle_checks_remaining} idle checks before shutdown."
                    )
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; fall through to the
        # shutdown in `finally`.
        pass
    finally:
        server.stop(0)
Ejemplo n.º 3
0
from ray._private.ray_logging import setup_logger
from ray._private.runtime_env.context import RuntimeEnvContext
from ray.core.generated.common_pb2 import Language
from ray.ray_constants import LOGGER_LEVEL, LOGGER_FORMAT

logger = logging.getLogger(__name__)

# Command-line interface for the worker bootstrap script. Unknown arguments
# are not rejected: they are collected and passed on to the worker command.
parser = argparse.ArgumentParser(
    description="Set up the environment for a Ray worker and launch the worker."
)
parser.add_argument(
    "--serialized-runtime-env-context",
    type=str,
    help="the serialized runtime env context",
)
parser.add_argument(
    "--language",
    type=str,
    help="the language type of the worker",
)

if __name__ == "__main__":
    # Configure logging before anything else runs.
    setup_logger(LOGGER_LEVEL, LOGGER_FORMAT)
    known_args, worker_argv = parser.parse_known_args()
    # NOTE(edoakes): args.serialized_runtime_env_context is only None when
    # we're starting the main Ray client proxy server. That case should
    # probably not even go through this codepath.
    serialized_context = known_args.serialized_runtime_env_context or "{}"
    context = RuntimeEnvContext.deserialize(serialized_context)
    worker_language = Language.Value(known_args.language)
    context.exec_worker(worker_argv, worker_language)