Example #1
    def terminate_node(self, node_id):
        node = self._get_cached_node(node_id)
        if self.cache_stopped_nodes:
            if node.spot_instance_request_id:
                cli_logger.print(
                    "Terminating instance {} " +
                    cf.dimmed("(cannot stop spot instances, only terminate)"),
                    node_id)  # todo: show node name?

                cli_logger.old_info(
                    logger,
                    "AWSNodeProvider: terminating node {} (spot nodes cannot "
                    "be stopped, only terminated)", node_id)
                node.terminate()
            else:
                cli_logger.print("Stopping instance {} " + cf.dimmed(
                    "(to terminate instead, "
                    "set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration)"),
                                 node_id)  # todo: show node name?

                cli_logger.old_info(
                    logger,
                    "AWSNodeProvider: stopping node {}. To terminate nodes "
                    "on stop, set 'cache_stopped_nodes: False' in the "
                    "provider config.".format(node_id))
                node.stop()
        else:
            node.terminate()

        self.tag_cache.pop(node_id, None)
        self.tag_cache_pending.pop(node_id, None)
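The snippet above also shows the calling convention used throughout these examples: cli_logger.print() takes a format string plus positional arguments and applies str.format() itself, while the cf helpers style fragments of the message. A minimal standalone sketch of that convention, assuming the import path used elsewhere in these examples (it may differ across Ray versions):

from ray.autoscaler._private.cli_logger import cli_logger, cf

node_id = "i-0123456789abcdef0"  # hypothetical instance id, for illustration only
cli_logger.print(
    "Stopping instance {} " +
    cf.dimmed("(to terminate instead, set `cache_stopped_nodes: False`)"),
    node_id)  # cli_logger formats the message; cf.dimmed styles the hint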
Example #2
        def do_sync(remote_path, local_path, allow_non_existing_paths=False):
            if allow_non_existing_paths and not os.path.exists(local_path):
                cli_logger.print("sync: {} does not exist. Skipping.",
                                 local_path)
                # Ignore missing source files. In the future we should support
                # the --delete-missing-args command to delete files that have
                # been removed
                return

            assert os.path.exists(local_path), local_path

            if os.path.isdir(local_path):
                if not local_path.endswith("/"):
                    local_path += "/"
                if not remote_path.endswith("/"):
                    remote_path += "/"

            with LogTimer(self.log_prefix +
                          "Synced {} to {}".format(local_path, remote_path)):
                is_docker = (self.docker_config
                             and self.docker_config["container_name"] != "")
                if not is_docker:
                    # The DockerCommandRunner handles this internally.
                    self.cmd_runner.run(
                        "mkdir -p {}".format(os.path.dirname(remote_path)),
                        run_env="host")
                sync_cmd(
                    local_path, remote_path, docker_mount_if_possible=True)

                if remote_path not in nolog_paths:
                    # todo: timed here?
                    cli_logger.print("{} from {}", cf.bold(remote_path),
                                     cf.bold(local_path))
Example #3
    def terminate_nodes(self, node_ids):
        if not node_ids:
            return
        if self.cache_stopped_nodes:
            spot_ids = []
            on_demand_ids = []

            for node_id in node_ids:
                if self._get_cached_node(node_id).spot_instance_request_id:
                    spot_ids += [node_id]
                else:
                    on_demand_ids += [node_id]

            if on_demand_ids:
                # todo: show node names?
                cli_logger.print(
                    "Stopping instances {} " + cf.dimmed(
                        "(to terminate instead, "
                        "set `cache_stopped_nodes: False` "
                        "under `provider` in the cluster configuration)"),
                    cli_logger.render_list(on_demand_ids))

                self.ec2.meta.client.stop_instances(InstanceIds=on_demand_ids)
            if spot_ids:
                cli_logger.print(
                    "Terminating instances {} " +
                    cf.dimmed("(cannot stop spot instances, only terminate)"),
                    cli_logger.render_list(spot_ids))

                self.ec2.meta.client.terminate_instances(InstanceIds=spot_ids)
        else:
            self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
Example #4
        def remaining_nodes():
            workers = provider.non_terminated_nodes(
                {TAG_RAY_NODE_KIND: NODE_KIND_WORKER})

            if keep_min_workers:
                min_workers = config.get("min_workers", 0)

                cli_logger.print(
                    "{} random worker nodes will not be shut down. " +
                    cf.dimmed("(due to {})"), cf.bold(min_workers),
                    cf.bold("--keep-min-workers"))
                cli_logger.old_info(logger,
                                    "teardown_cluster: Keeping {} nodes...",
                                    min_workers)

                workers = random.sample(workers, len(workers) - min_workers)

            # todo: it's weird to kill the head node but not all workers
            if workers_only:
                cli_logger.print(
                    "The head node will not be shut down. " +
                    cf.dimmed("(due to {})"), cf.bold("--workers-only"))

                return workers

            head = provider.non_terminated_nodes(
                {TAG_RAY_NODE_KIND: NODE_KIND_HEAD})

            return head + workers
Example #5
    def do_update(self):
        self.provider.set_node_tags(
            self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
        cli_logger.labeled_value("New status", STATUS_WAITING_FOR_SSH)

        deadline = time.time() + NODE_START_WAIT_S
        self.wait_ready(deadline)

        node_tags = self.provider.node_tags(self.node_id)
        logger.debug("Node tags: {}".format(str(node_tags)))

        if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash:
            # When resuming from a stopped instance the runtime_hash may be the
            # same, but the container will not be started.
            self.cmd_runner.run_init(
                as_head=self.is_head_node, file_mounts=self.file_mounts)

        # runtime_hash will only change whenever the user restarts
        # or updates their cluster with `get_or_create_head_node`
        if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash and (
                not self.file_mounts_contents_hash
                or node_tags.get(TAG_RAY_FILE_MOUNTS_CONTENTS) ==
                self.file_mounts_contents_hash):
            # todo: we lie in the confirmation message since
            # full setup might be cancelled here
            cli_logger.print(
                "Configuration already up to date, "
                "skipping file mounts, initalization and setup commands.",
                _numbered=("[]", "2-5", 6))
            cli_logger.old_info(logger,
                                "{}{} already up-to-date, skip to ray start",
                                self.log_prefix, self.node_id)
Example #6
    def terminate_node(self, node_id):
        node = self._get_cached_node(node_id)
        if self.cache_stopped_nodes:
            if node.spot_instance_request_id:
                cli_logger.print(
                    "Terminating instance {} " +
                    cf.dimmed("(cannot stop spot instances, only terminate)"),
                    node_id)  # todo: show node name?
                node.terminate()
            else:
                cli_logger.print("Stopping instance {} " + cf.dimmed(
                    "(to terminate instead, "
                    "set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration)"),
                                 node_id)  # todo: show node name?
                node.stop()
        else:
            node.terminate()

        # TODO (Alex): We are leaking the tag cache here. Naively, we would
        # want to just remove the cache entry here, but terminating can be
        # asynchronous or fail, which would result in a use-after-free error.
        # If this leak becomes bad, we can garbage collect the tag cache when
        # the node cache is updated.
        pass
Example #7
def get_local_dump_archive(stream: bool = False,
                           output: Optional[str] = None,
                           logs: bool = True,
                           debug_state: bool = True,
                           pip: bool = True,
                           processes: bool = True,
                           processes_verbose: bool = False) -> Optional[str]:
    if stream and output:
        raise ValueError(
            "You can only use either `--output` or `--stream`, but not both.")

    parameters = GetParameters(
        logs=logs,
        debug_state=debug_state,
        pip=pip,
        processes=processes,
        processes_verbose=processes_verbose)

    with Archive() as archive:
        get_all_local_data(archive, parameters)

    tmp = archive.file

    if stream:
        with open(tmp, "rb") as fp:
            os.write(1, fp.read())
        os.remove(tmp)
        return None

    target = output or os.path.join(os.getcwd(), os.path.basename(tmp))
    os.rename(tmp, target)
    cli_logger.print(f"Created local data archive at {target}")

    return target
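The stream-versus-output branching above follows a simple pattern: either dump the finished archive to stdout and delete the temporary file, or move it to the requested path. A generic sketch of that pattern using only the standard library (it deliberately stands in for, and does not use, the Ray Archive class):

import os
import tempfile

fd, tmp = tempfile.mkstemp(suffix=".tar.gz")  # stand-in for archive.file
os.close(fd)

stream = False  # hypothetical flag mirroring --stream
if stream:
    with open(tmp, "rb") as fp:
        os.write(1, fp.read())  # write the archive bytes to stdout (fd 1)
    os.remove(tmp)
else:
    target = os.path.join(os.getcwd(), os.path.basename(tmp))
    os.rename(tmp, target)  # keep the archive at the target path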
Example #8
def create_archive_for_local_and_remote_nodes(archive: Archive,
                                              remote_nodes: Sequence[Node],
                                              parameters: GetParameters):
    """Create an archive combining data from the local and remote nodes.

    This will parallelize calls to get data from remote nodes.

    Args:
        archive: Archive object to add data to.
        remote_nodes (Sequence[Node]): Sequence of remote nodes.
        parameters: Parameters (settings) for getting data.

    Returns:
        Open archive object.

    """
    if not archive.is_open:
        archive.open()

    try:
        create_and_add_local_data_to_local_archive(archive, parameters)
    except CommandFailed as exc:
        cli_logger.error(exc)

    create_archive_for_remote_nodes(archive, remote_nodes, parameters)

    cli_logger.print(f"Collected data from local node and {len(remote_nodes)} "
                     f"remote nodes.")
    return archive
Example #9
    def wait_ready(self, deadline):
        with cli_logger.group(
            "Waiting for SSH to become available", _numbered=("[]", 1, NUM_SETUP_STEPS)
        ):
            with LogTimer(self.log_prefix + "Got remote shell"):

                cli_logger.print("Running `{}` as a test.", cf.bold("uptime"))
                first_conn_refused_time = None
                while True:
                    if time.time() > deadline:
                        raise Exception("wait_ready timeout exceeded.")
                    if self.provider.is_terminated(self.node_id):
                        raise Exception(
                            "wait_ready aborting because node "
                            "detected as terminated."
                        )

                    try:
                        # Run outside of the container
                        self.cmd_runner.run("uptime", timeout=5, run_env="host")
                        cli_logger.success("Success.")
                        return True
                    except ProcessRunnerError as e:
                        first_conn_refused_time = cmd_output_util.handle_ssh_fails(
                            e,
                            first_conn_refused_time,
                            retry_interval=READY_CHECK_INTERVAL,
                        )
                        time.sleep(READY_CHECK_INTERVAL)
                    except Exception as e:
                        # TODO(maximsmol): we should not be ignoring
                        # exceptions if they get filtered properly
                        # (new style log + non-interactive shells)
                        #
                        # however threading this configuration state
                        # is a pain and I'm leaving it for later

                        retry_str = "(" + str(e) + ")"
                        if hasattr(e, "cmd"):
                            if isinstance(e.cmd, str):
                                cmd_ = e.cmd
                            elif isinstance(e.cmd, list):
                                cmd_ = " ".join(e.cmd)
                            else:
                                logger.debug(
                                    f"e.cmd type ({type(e.cmd)}) not list or str."
                                )
                                cmd_ = str(e.cmd)
                            retry_str = "(Exit Status {}): {}".format(
                                e.returncode, cmd_
                            )

                        cli_logger.print(
                            "SSH still not available {}, retrying in {} seconds.",
                            cf.dimmed(retry_str),
                            cf.bold(str(READY_CHECK_INTERVAL)),
                        )

                        time.sleep(READY_CHECK_INTERVAL)
Example #10
    def terminate_nodes(self, node_ids):
        if not node_ids:
            return

        terminate_instances_func = self.ec2.meta.client.terminate_instances
        stop_instances_func = self.ec2.meta.client.stop_instances

        # In some cases, this function stops some nodes, but terminates others.
        # Each of these requires a different EC2 API call. So, we use the
        # "nodes_to_terminate" dict below to keep track of exactly which API
        # call will be used to stop/terminate which set of nodes. The key is
        # the function to use, and the value is the list of nodes to terminate
        # with that function.
        nodes_to_terminate = {
            terminate_instances_func: [],
            stop_instances_func: []
        }

        if self.cache_stopped_nodes:
            spot_ids = []
            on_demand_ids = []

            for node_id in node_ids:
                if self._get_cached_node(node_id).spot_instance_request_id:
                    spot_ids += [node_id]
                else:
                    on_demand_ids += [node_id]

            if on_demand_ids:
                # todo: show node names?
                cli_logger.print(
                    "Stopping instances {} " + cf.dimmed(
                        "(to terminate instead, "
                        "set `cache_stopped_nodes: False` "
                        "under `provider` in the cluster configuration)"),
                    cli_logger.render_list(on_demand_ids),
                )

            if spot_ids:
                cli_logger.print(
                    "Terminating instances {} " +
                    cf.dimmed("(cannot stop spot instances, only terminate)"),
                    cli_logger.render_list(spot_ids),
                )

            nodes_to_terminate[stop_instances_func] = on_demand_ids
            nodes_to_terminate[terminate_instances_func] = spot_ids
        else:
            nodes_to_terminate[terminate_instances_func] = node_ids

        max_terminate_nodes = (self.max_terminate_nodes
                               if self.max_terminate_nodes is not None else
                               len(node_ids))

        for terminate_func, nodes in nodes_to_terminate.items():
            for start in range(0, len(nodes), max_terminate_nodes):
                terminate_func(InstanceIds=nodes[start:start +
                                                 max_terminate_nodes])
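The final loop above batches the EC2 API calls so that no single request carries more than max_terminate_nodes instance ids; the chunking itself is just range() with a step. A tiny standalone sketch of that idiom:

node_ids = [f"i-{i:04d}" for i in range(7)]  # hypothetical instance ids
max_terminate_nodes = 3

for start in range(0, len(node_ids), max_terminate_nodes):
    batch = node_ids[start:start + max_terminate_nodes]
    print(batch)  # stand-in for terminate_func(InstanceIds=batch)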
Example #11
def list(address: Optional[str]):
    """Lists all running jobs and their information.

    Example:
        ray job list
    """
    client = _get_sdk_client(address)
    # Set no_format to True because the logs may have unescaped "{" and "}"
    # and the CLILogger calls str.format().
    cli_logger.print(pprint.pformat(client.list_jobs()), no_format=True)
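As the comment notes, no_format=True matters because cli_logger.print() normally runs str.format() on its message, so pre-rendered text containing literal braces would otherwise be mangled or raise a formatting error. A small sketch, with the same import-path assumption as above:

from ray.autoscaler._private.cli_logger import cli_logger

raw = '{"job_id": "abc123", "status": "RUNNING"}'  # hypothetical pre-rendered text
cli_logger.print(raw, no_format=True)  # printed verbatim, no formatting pass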
Example #12
def run(
    config_or_import_path: str,
    runtime_env: str,
    runtime_env_json: str,
    working_dir: str,
    app_dir: str,
    address: str,
    host: str,
    port: int,
    blocking: bool,
):
    sys.path.insert(0, app_dir)

    final_runtime_env = parse_runtime_env_args(
        runtime_env=runtime_env,
        runtime_env_json=runtime_env_json,
        working_dir=working_dir,
    )

    if pathlib.Path(config_or_import_path).is_file():
        config_path = config_or_import_path
        cli_logger.print(f'Deploying from config file: "{config_path}".')

        with open(config_path, "r") as config_file:
            config = ServeApplicationSchema.parse_obj(
                yaml.safe_load(config_file))
        is_config = True
    else:
        import_path = config_or_import_path
        cli_logger.print(f'Deploying from import path: "{import_path}".')
        node = import_attr(import_path)
        is_config = False

    # Setting the runtime_env here will set defaults for the deployments.
    ray.init(address=address,
             namespace=SERVE_NAMESPACE,
             runtime_env=final_runtime_env)
    client = serve.start(detached=True)

    try:
        if is_config:
            client.deploy_app(config)
        else:
            serve.run(node, host=host, port=port)
        cli_logger.success("Deployed successfully.")

        if blocking:
            while True:
                # Block, letting Ray print logs to the terminal.
                time.sleep(10)

    except KeyboardInterrupt:
        cli_logger.info("Got KeyboardInterrupt, shutting down...")
        serve.shutdown()
        sys.exit()
Example #13
    def wait_ready(self, deadline):
        with cli_logger.group("Waiting for SSH to become available",
                              _numbered=("[]", 1, 6)):
            with LogTimer(self.log_prefix + "Got remote shell"):
                cli_logger.old_info(logger, "{}Waiting for remote shell...",
                                    self.log_prefix)

                cli_logger.print("Running `{}` as a test.", cf.bold("uptime"))
                first_conn_refused_time = None
                while time.time() < deadline and \
                        not self.provider.is_terminated(self.node_id):
                    try:
                        cli_logger.old_debug(logger,
                                             "{}Waiting for remote shell...",
                                             self.log_prefix)

                        # Run outside of the container
                        self.cmd_runner.run("uptime",
                                            timeout=5,
                                            run_env="host")
                        cli_logger.old_debug(logger, "Uptime succeeded.")
                        cli_logger.success("Success.")
                        return True
                    except ProcessRunnerError as e:
                        first_conn_refused_time = \
                            cmd_output_util.handle_ssh_fails(
                                e, first_conn_refused_time,
                                retry_interval=READY_CHECK_INTERVAL)
                        time.sleep(READY_CHECK_INTERVAL)
                    except Exception as e:
                        # TODO(maximsmol): we should not be ignoring
                        # exceptions if they get filtered properly
                        # (new style log + non-interactive shells)
                        #
                        # however threading this configuration state
                        # is a pain and I'm leaving it for later

                        retry_str = str(e)
                        if hasattr(e, "cmd"):
                            retry_str = "(Exit Status {}): {}".format(
                                e.returncode, " ".join(e.cmd))

                        cli_logger.print(
                            "SSH still not available {}, "
                            "retrying in {} seconds.", cf.dimmed(retry_str),
                            cf.bold(str(READY_CHECK_INTERVAL)))
                        cli_logger.old_debug(logger,
                                             "{}Node not up, retrying: {}",
                                             self.log_prefix, retry_str)

                        time.sleep(READY_CHECK_INTERVAL)

        assert False, "Unable to connect to node"
Example #14
    def __exit__(self, *error_vals):
        if cli_logger.log_style != "record":
            return

        td = datetime.datetime.utcnow() - self._start_time
        status = ""
        if self._show_status:
            status = "failed" if any(error_vals) else "succeeded"
        cli_logger.print(" ".join([
            self._message, status,
            "[LogTimer={:.0f}ms]".format(td.total_seconds() * 1000)
        ]))
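Elsewhere in these examples LogTimer is used as a context manager; on exit it prints the message with a [LogTimer=...ms] suffix, and per the snippet above it only does so when cli_logger.log_style is "record". A minimal usage sketch, assuming the private autoscaler import path:

import time

from ray.autoscaler._private.log_timer import LogTimer

with LogTimer("sync_files: synced /tmp/src to /tmp/dst"):
    time.sleep(0.1)  # hypothetical work being timed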
Example #15
def handle_ssh_fails(e, first_conn_refused_time, retry_interval):
    """Handle SSH system failures coming from a subprocess.

    Args:
        e: The `ProcessRunnerException` to handle.
        first_conn_refused_time:
            The time (as reported by this function) or None,
            indicating the last time a CONN_REFUSED error was caught.

            After exceeding a patience value, the program will be aborted
            since SSH will likely never recover.
        retry_interval: The interval after which the command will be retried,
                        used here just to inform the user.
    """
    if e.msg_type != "ssh_command_failed":
        return

    if e.special_case == "ssh_conn_refused":
        if (
            first_conn_refused_time is not None
            and time.time() - first_conn_refused_time > CONN_REFUSED_PATIENCE
        ):
            cli_logger.error(
                "SSH connection was being refused "
                "for {} seconds. Head node assumed "
                "unreachable.",
                cf.bold(str(CONN_REFUSED_PATIENCE)),
            )
            cli_logger.abort(
                "Check the node's firewall settings "
                "and the cloud network configuration."
            )

        cli_logger.warning("SSH connection was refused.")
        cli_logger.warning(
            "This might mean that the SSH daemon is "
            "still setting up, or that "
            "the host is inaccessable (e.g. due to "
            "a firewall)."
        )

        return time.time()

    if e.special_case in ["ssh_timeout", "ssh_conn_refused"]:
        cli_logger.print(
            "SSH still not available, " "retrying in {} seconds.",
            cf.bold(str(retry_interval)),
        )
    else:
        raise e

    return first_conn_refused_time
Example #16
def run(
    config_or_import_path: str,
    runtime_env: str,
    runtime_env_json: str,
    working_dir: str,
    app_dir: str,
    address: str,
    host: str,
    port: int,
    blocking: bool,
):
    sys.path.insert(0, app_dir)

    final_runtime_env = parse_runtime_env_args(
        runtime_env=runtime_env,
        runtime_env_json=runtime_env_json,
        working_dir=working_dir,
    )

    app_or_node = None
    if pathlib.Path(config_or_import_path).is_file():
        config_path = config_or_import_path
        cli_logger.print(f"Loading app from config file: '{config_path}'.")
        with open(config_path, "r") as config_file:
            app_or_node = Application.from_yaml(config_file)
    else:
        import_path = config_or_import_path
        cli_logger.print(f"Loading app from import path: '{import_path}'.")
        app_or_node = import_attr(import_path)

    # Setting the runtime_env here will set defaults for the deployments.
    ray.init(address=address, namespace="serve", runtime_env=final_runtime_env)

    try:
        serve.run(app_or_node, host=host, port=port)
        cli_logger.success("Deployed successfully!\n")

        if blocking:
            while True:
                statuses = serve_application_status_to_schema(
                    get_deployment_statuses()
                ).json(indent=4)
                cli_logger.info(f"{statuses}")
                time.sleep(10)

    except KeyboardInterrupt:
        cli_logger.info("Got KeyboardInterrupt, shutting down...")
        serve.shutdown()
        sys.exit()
Example #17
def kill_node(config_file, yes, hard, override_cluster_name):
    """Kills a random Raylet worker."""

    config = yaml.safe_load(open(config_file).read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = _bootstrap_config(config)

    cli_logger.confirm(yes, "A random node will be killed.")
    cli_logger.old_confirm("This will kill a node in your cluster", yes)

    provider = _get_node_provider(config["provider"], config["cluster_name"])
    try:
        nodes = provider.non_terminated_nodes({
            TAG_RAY_NODE_KIND: NODE_KIND_WORKER
        })
        node = random.choice(nodes)
        cli_logger.print("Shutdown " + cf.bold("{}"), node)
        cli_logger.old_info(logger, "kill_node: Shutdown worker {}", node)
        if hard:
            provider.terminate_node(node)
        else:
            updater = NodeUpdaterThread(
                node_id=node,
                provider_config=config["provider"],
                provider=provider,
                auth_config=config["auth"],
                cluster_name=config["cluster_name"],
                file_mounts=config["file_mounts"],
                initialization_commands=[],
                setup_commands=[],
                ray_start_commands=[],
                runtime_hash="",
                file_mounts_contents_hash="",
                is_head_node=False,
                docker_config=config.get("docker"))

            _exec(updater, "ray stop", False, False)

        time.sleep(POLL_INTERVAL)

        if config.get("provider", {}).get("use_internal_ips", False) is True:
            node_ip = provider.internal_ip(node)
        else:
            node_ip = provider.external_ip(node)
    finally:
        provider.cleanup()

    return node_ip
Example #18
    def fillout_available_node_types_resources(
            cluster_config: Dict[str, Any]) -> Dict[str, Any]:
        """Fills out missing "resources" field for available_node_types."""
        if "available_node_types" not in cluster_config:
            return cluster_config
        cluster_config = copy.deepcopy(cluster_config)

        instances_list = list_ec2_instances(
            cluster_config["provider"]["region"])
        instances_dict = {
            instance["InstanceType"]: instance
            for instance in instances_list
        }
        available_node_types = cluster_config["available_node_types"]
        for node_type in available_node_types:
            instance_type = available_node_types[node_type]["node_config"][
                "InstanceType"]
            if instance_type in instances_dict:
                cpus = instances_dict[instance_type]["VCpuInfo"][
                    "DefaultVCpus"]
                autodetected_resources = {"CPU": cpus}
                gpus = instances_dict[instance_type].get("GpuInfo",
                                                         {}).get("Gpus")
                if gpus is not None:
                    # TODO(ameer): currently we support one gpu type per node.
                    assert len(gpus) == 1
                    gpu_name = gpus[0]["Name"]
                    autodetected_resources.update({
                        "GPU":
                        gpus[0]["Count"],
                        f"accelerator_type:{gpu_name}":
                        1
                    })
                autodetected_resources.update(
                    available_node_types[node_type].get("resources", {}))
                if autodetected_resources != \
                        available_node_types[node_type].get("resources", {}):
                    available_node_types[node_type][
                        "resources"] = autodetected_resources
                    cli_logger.print("Updating the resources of {} to {}.",
                                     node_type, autodetected_resources)
            else:
                raise ValueError("Instance type " + instance_type +
                                 " is not available in AWS region: " +
                                 cluster_config["provider"]["region"] + ".")
        return cluster_config
Example #19
    def terminate_nodes(self, node_ids):
        if not node_ids:
            return
        if self.cache_stopped_nodes:
            spot_ids = []
            on_demand_ids = []

            for node_id in node_ids:
                if self._get_cached_node(node_id).spot_instance_request_id:
                    spot_ids += [node_id]
                else:
                    on_demand_ids += [node_id]

            if on_demand_ids:
                # todo: show node names?
                cli_logger.print(
                    "Stopping instances {} " + cf.dimmed(
                        "(to terminate instead, "
                        "set `cache_stopped_nodes: False` "
                        "under `provider` in the cluster configuration)"),
                    cli_logger.render_list(on_demand_ids))
                cli_logger.old_info(
                    logger,
                    "AWSNodeProvider: stopping nodes {}. To terminate nodes "
                    "on stop, set 'cache_stopped_nodes: False' in the "
                    "provider config.", on_demand_ids)

                self.ec2.meta.client.stop_instances(InstanceIds=on_demand_ids)
            if spot_ids:
                cli_logger.print(
                    "Terminating instances {} " +
                    cf.dimmed("(cannot stop spot instances, only terminate)"),
                    cli_logger.render_list(spot_ids))
                cli_logger.old_info(
                    logger,
                    "AWSNodeProvider: terminating nodes {} (spot nodes cannot "
                    "be stopped, only terminated)", spot_ids)

                self.ec2.meta.client.terminate_instances(InstanceIds=spot_ids)
        else:
            self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)

        for node_id in node_ids:
            self.tag_cache.pop(node_id, None)
            self.tag_cache_pending.pop(node_id, None)
Example #20
def get_info_from_ray_cluster_config(
        cluster_config: str
) -> Tuple[List[str], str, str, Optional[str], Optional[str]]:
    """Get information from Ray cluster config.

    Return list of host IPs, ssh user, ssh key file, optional docker
    container name, and optional cluster name.

    Args:
        cluster_config (str): Path to ray cluster config.

    Returns:
        Tuple of list of host IPs, ssh user name, ssh key file path,
            optional docker container name, optional cluster name.
    """
    from ray.autoscaler._private.commands import _bootstrap_config

    cli_logger.print(f"Retrieving cluster information from ray cluster file: "
                     f"{cluster_config}")

    cluster_config = os.path.expanduser(cluster_config)

    config = yaml.safe_load(open(cluster_config).read())
    config = _bootstrap_config(config, no_config_cache=True)

    provider = _get_node_provider(config["provider"], config["cluster_name"])
    head_nodes = provider.non_terminated_nodes({
        TAG_RAY_NODE_KIND: NODE_KIND_HEAD
    })
    worker_nodes = provider.non_terminated_nodes({
        TAG_RAY_NODE_KIND: NODE_KIND_WORKER
    })

    hosts = [provider.external_ip(node) for node in head_nodes + worker_nodes]
    ssh_user = config["auth"]["ssh_user"]
    ssh_key = config["auth"]["ssh_private_key"]

    docker = None
    docker_config = config.get("docker", None)
    if docker_config:
        docker = docker_config.get("container_name", None)

    cluster_name = config.get("cluster_name", None)

    return hosts, ssh_user, ssh_key, docker, cluster_name
Example #21
    def _wait_for_ip(self, deadline):
        # if we have IP do not print waiting info
        ip = self._get_node_ip()
        if ip is not None:
            cli_logger.labeled_value("Fetched IP", ip)
            return ip

        interval = 10
        with cli_logger.group("Waiting for IP"):
            while time.time() < deadline and \
                    not self.provider.is_terminated(self.node_id):
                ip = self._get_node_ip()
                if ip is not None:
                    cli_logger.labeled_value("Received", ip)
                    return ip
                cli_logger.print("Not yet available, retrying in {} seconds",
                                 cf.bold(str(interval)))
                time.sleep(interval)

        return None
Example #22
    def terminate_node(self, node_id):
        node = self._get_cached_node(node_id)
        if self.cache_stopped_nodes:
            if node.spot_instance_request_id:
                cli_logger.print(
                    "Terminating instance {} " +
                    cf.dimmed("(cannot stop spot instances, only terminate)"),
                    node_id)  # todo: show node name?
                node.terminate()
            else:
                cli_logger.print("Stopping instance {} " + cf.dimmed(
                    "(to terminate instead, "
                    "set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration)"),
                                 node_id)  # todo: show node name?
                node.stop()
        else:
            node.terminate()

        self.tag_cache.pop(node_id, None)
        self.tag_cache_pending.pop(node_id, None)
Example #23
def stop(address: Optional[str], no_wait: bool, job_id: str):
    """Attempts to stop a job.

    Example:
        ray job stop <my_job_id>
    """
    client = _get_sdk_client(address)
    cli_logger.print(f"Attempting to stop job {job_id}")
    client.stop_job(job_id)

    if no_wait:
        return
    else:
        cli_logger.print(f"Waiting for job '{job_id}' to exit "
                         f"(disable with --no-wait):")

    while True:
        status = client.get_job_status(job_id)
        if status in {
                JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED
        }:
            _log_job_status(client, job_id)
            break
        else:
            cli_logger.print(f"Job has not exited yet. Status: {status}")
            time.sleep(1)
Example #24
def logs(address: Optional[str], job_id: str, follow: bool):
    """Gets the logs of a job.

    Example:
        ray job logs <my_job_id>
    """
    client = _get_sdk_client(address)
    sdk_version = client.get_version()
    # sdk version 0 did not have log streaming
    if follow:
        if int(sdk_version) > 0:
            asyncio.get_event_loop().run_until_complete(
                _tail_logs(client, job_id))
        else:
            cli_logger.warning(
                "Tailing logs is not enabled for job sdk client version "
                f"{sdk_version}. Please upgrade your ray to latest version "
                "for this feature.")
    else:
        # Set no_format to True because the logs may have unescaped "{" and "}"
        # and the CLILogger calls str.format().
        cli_logger.print(client.get_job_logs(job_id), end="", no_format=True)
Example #25
def show_usage_stats_prompt() -> None:
    if not usage_stats_prompt_enabled():
        return

    from ray.autoscaler._private.cli_logger import cli_logger

    usage_stats_enabledness = _usage_stats_enabledness()
    if usage_stats_enabledness is UsageStatsEnabledness.DISABLED_EXPLICITLY:
        cli_logger.print(usage_constant.USAGE_STATS_DISABLED_MESSAGE)
    elif usage_stats_enabledness is UsageStatsEnabledness.ENABLED_BY_DEFAULT:

        if cli_logger.interactive:
            enabled = cli_logger.confirm(
                False,
                usage_constant.USAGE_STATS_CONFIRMATION_MESSAGE,
                _default=True,
                _timeout_s=10,
            )
            set_usage_stats_enabled_via_env_var(enabled)
            # Remember user's choice.
            try:
                set_usage_stats_enabled_via_config(enabled)
            except Exception as e:
                logger.debug(
                    f"Failed to persist usage stats choice for future clusters: {e}"
                )
            if enabled:
                cli_logger.print(usage_constant.USAGE_STATS_ENABLED_MESSAGE)
            else:
                cli_logger.print(usage_constant.USAGE_STATS_DISABLED_MESSAGE)
        else:
            cli_logger.print(
                usage_constant.USAGE_STATS_ENABLED_BY_DEFAULT_MESSAGE)
    else:
        assert usage_stats_enabledness is UsageStatsEnabledness.ENABLED_EXPLICITLY
        cli_logger.print(usage_constant.USAGE_STATS_ENABLED_MESSAGE)
Example #26
def _log_job_status(client: JobSubmissionClient, job_id: str):
    info = client.get_job_info(job_id)
    if info.status == JobStatus.SUCCEEDED:
        _log_big_success_msg(f"Job '{job_id}' succeeded")
    elif info.status == JobStatus.STOPPED:
        cli_logger.warning(f"Job '{job_id}' was stopped")
    elif info.status == JobStatus.FAILED:
        _log_big_error_msg(f"Job '{job_id}' failed")
        if info.message is not None:
            cli_logger.print(f"Status message: {info.message}", no_format=True)
    else:
        # Catch-all.
        cli_logger.print(f"Status for job '{job_id}': {info.status}")
        if info.message is not None:
            cli_logger.print(f"Status message: {info.message}", no_format=True)
Example #27
def submit(
    address: Optional[str],
    job_id: Optional[str],
    runtime_env: Optional[str],
    runtime_env_json: Optional[str],
    working_dir: Optional[str],
    entrypoint: Tuple[str],
    no_wait: bool,
):
    """Submits a job to be run on the cluster.

    Example:
        ray job submit -- python my_script.py --arg=val
    """
    client = _get_sdk_client(address, create_cluster_if_needed=True)

    final_runtime_env = parse_runtime_env_args(
        runtime_env=runtime_env,
        runtime_env_json=runtime_env_json,
        working_dir=working_dir,
    )

    job_id = client.submit_job(
        entrypoint=list2cmdline(entrypoint),
        job_id=job_id,
        runtime_env=final_runtime_env,
    )

    _log_big_success_msg(f"Job '{job_id}' submitted successfully")

    with cli_logger.group("Next steps"):
        cli_logger.print("Query the logs of the job:")
        with cli_logger.indented():
            cli_logger.print(cf.bold(f"ray job logs {job_id}"))

        cli_logger.print("Query the status of the job:")
        with cli_logger.indented():
            cli_logger.print(cf.bold(f"ray job status {job_id}"))

        cli_logger.print("Request the job to be stopped:")
        with cli_logger.indented():
            cli_logger.print(cf.bold(f"ray job stop {job_id}"))

    cli_logger.newline()
    sdk_version = client.get_version()
    # sdk version 0 does not have log streaming
    if not no_wait:
        if int(sdk_version) > 0:
            cli_logger.print("Tailing logs until the job exits "
                             "(disable with --no-wait):")
            asyncio.get_event_loop().run_until_complete(
                _tail_logs(client, job_id))
        else:
            cli_logger.warning(
                "Tailing logs is not enabled for job sdk client version "
                f"{sdk_version}. Please upgrade your ray to latest version "
                "for this feature.")
Example #28
def exec_cluster(config_file: str,
                 *,
                 cmd: str = None,
                 run_env: str = "auto",
                 screen: bool = False,
                 tmux: bool = False,
                 stop: bool = False,
                 start: bool = False,
                 override_cluster_name: Optional[str] = None,
                 no_config_cache: bool = False,
                 port_forward: Any = None,
                 with_output: bool = False):
    """Runs a command on the specified cluster.

    Arguments:
        config_file: path to the cluster yaml
        cmd: command to run
        run_env: whether to run the command on the host or in a container.
            Select between "auto", "host" and "docker"
        screen: whether to run in a screen
        tmux: whether to run in a tmux session
        stop: whether to stop the cluster after command run
        start: whether to start the cluster if it isn't up
        override_cluster_name: set the name of the cluster
        port_forward (int or list[int]): port(s) to forward
    """
    assert not (screen and tmux), "Can specify only one of `screen` or `tmux`."
    assert run_env in RUN_ENV_TYPES, "--run_env must be in {}".format(
        RUN_ENV_TYPES)
    # TODO(rliaw): We default this to True to maintain backwards-compat.
    # In the future we would want to support disabling login-shells
    # and interactivity.
    cmd_output_util.set_allow_interactive(True)

    config = yaml.safe_load(open(config_file).read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = _bootstrap_config(config, no_config_cache=no_config_cache)

    head_node = _get_head_node(config,
                               config_file,
                               override_cluster_name,
                               create_if_needed=start)

    provider = _get_node_provider(config["provider"], config["cluster_name"])
    try:
        updater = NodeUpdaterThread(node_id=head_node,
                                    provider_config=config["provider"],
                                    provider=provider,
                                    auth_config=config["auth"],
                                    cluster_name=config["cluster_name"],
                                    file_mounts=config["file_mounts"],
                                    initialization_commands=[],
                                    setup_commands=[],
                                    ray_start_commands=[],
                                    runtime_hash="",
                                    file_mounts_contents_hash="",
                                    is_head_node=True,
                                    docker_config=config.get("docker"))
        shutdown_after_run = False
        if cmd and stop:
            cmd += "; ".join([
                "ray stop",
                "ray teardown ~/ray_bootstrap_config.yaml --yes --workers-only"
            ])
            shutdown_after_run = True

        result = _exec(updater,
                       cmd,
                       screen,
                       tmux,
                       port_forward=port_forward,
                       with_output=with_output,
                       run_env=run_env,
                       shutdown_after_run=shutdown_after_run)
        if tmux or screen:
            attach_command_parts = ["ray attach", config_file]
            if override_cluster_name is not None:
                attach_command_parts.append(
                    "--cluster-name={}".format(override_cluster_name))
            if tmux:
                attach_command_parts.append("--tmux")
            elif screen:
                attach_command_parts.append("--screen")

            attach_command = " ".join(attach_command_parts)
            cli_logger.print("Run `{}` to check command status.",
                             cf.bold(attach_command))

            attach_info = "Use `{}` to check on command status.".format(
                attach_command)
            cli_logger.old_info(logger, attach_info)
        return result
    finally:
        provider.cleanup()
Example #29
def get_or_create_head_node(config,
                            config_file,
                            no_restart,
                            restart_only,
                            yes,
                            override_cluster_name,
                            _provider=None,
                            _runner=subprocess):
    """Create the cluster head node, which in turn creates the workers."""
    provider = (_provider or _get_node_provider(config["provider"],
                                                config["cluster_name"]))

    config = copy.deepcopy(config)
    config_file = os.path.abspath(config_file)
    try:
        head_node_tags = {
            TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
        }
        nodes = provider.non_terminated_nodes(head_node_tags)
        if len(nodes) > 0:
            head_node = nodes[0]
        else:
            head_node = None

        if not head_node:
            cli_logger.confirm(yes, "No head node found. "
                               "Launching a new cluster.",
                               _abort=True)
            cli_logger.old_confirm("This will create a new cluster", yes)
        elif not no_restart:
            cli_logger.old_confirm("This will restart cluster services", yes)

        if head_node:
            if restart_only:
                cli_logger.confirm(
                    yes, "Updating cluster configuration and "
                    "restarting the cluster Ray runtime. "
                    "Setup commands will not be run due to `{}`.\n",
                    cf.bold("--restart-only"),
                    _abort=True)
            elif no_restart:
                cli_logger.print(
                    "Cluster Ray runtime will not be restarted due "
                    "to `{}`.", cf.bold("--no-restart"))
                cli_logger.confirm(yes, "Updating cluster configuration and "
                                   "running setup commands.",
                                   _abort=True)
            else:
                cli_logger.print(
                    "Updating cluster configuration and running full setup.")
                cli_logger.confirm(
                    yes,
                    cf.bold("Cluster Ray runtime will be restarted."),
                    _abort=True)
        cli_logger.newline()

        # TODO(ekl) this logic is duplicated in node_launcher.py (keep in sync)
        head_node_config = copy.deepcopy(config["head_node"])
        if "head_node_type" in config:
            head_node_tags[TAG_RAY_USER_NODE_TYPE] = config["head_node_type"]
            head_node_config.update(config["available_node_types"][
                config["head_node_type"]]["node_config"])

        launch_hash = hash_launch_conf(head_node_config, config["auth"])
        if head_node is None or provider.node_tags(head_node).get(
                TAG_RAY_LAUNCH_CONFIG) != launch_hash:
            with cli_logger.group("Acquiring an up-to-date head node"):
                if head_node is not None:
                    cli_logger.print(
                        "Currently running head node is out-of-date with "
                        "cluster configuration")
                    cli_logger.print(
                        "hash is {}, expected {}",
                        cf.bold(
                            provider.node_tags(head_node).get(
                                TAG_RAY_LAUNCH_CONFIG)), cf.bold(launch_hash))
                    cli_logger.confirm(yes, "Relaunching it.", _abort=True)
                    cli_logger.old_confirm(
                        "Head node config out-of-date. It will be terminated",
                        yes)

                    cli_logger.old_info(
                        logger, "get_or_create_head_node: "
                        "Shutting down outdated head node {}", head_node)

                    provider.terminate_node(head_node)
                    cli_logger.print("Terminated head node {}", head_node)

                cli_logger.old_info(
                    logger,
                    "get_or_create_head_node: Launching new head node...")

                head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
                head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
                    config["cluster_name"])
                provider.create_node(head_node_config, head_node_tags, 1)
                cli_logger.print("Launched a new head node")

                start = time.time()
                head_node = None
                with cli_logger.timed("Fetching the new head node"):
                    while True:
                        if time.time() - start > 50:
                            cli_logger.abort(
                                "Head node fetch timed out.")  # todo: msg
                            raise RuntimeError("Failed to create head node.")
                        nodes = provider.non_terminated_nodes(head_node_tags)
                        if len(nodes) == 1:
                            head_node = nodes[0]
                            break
                        time.sleep(POLL_INTERVAL)
                cli_logger.newline()

        with cli_logger.group(
                "Setting up head node",
                _numbered=("<>", 1, 1),
                # cf.bold(provider.node_tags(head_node)[TAG_RAY_NODE_NAME]),
                _tags=dict()):  # add id, ARN to tags?

            # TODO(ekl) right now we always update the head node even if the
            # hash matches.
            # We could prompt the user for what they want to do here.
            # No need to pass in cluster_sync_files because we use this
            # hash to set up the head node
            (runtime_hash, file_mounts_contents_hash) = hash_runtime_conf(
                config["file_mounts"], None, config)

            cli_logger.old_info(
                logger,
                "get_or_create_head_node: Updating files on head node...")

            # Rewrite the auth config so that the head
            # node can update the workers
            remote_config = copy.deepcopy(config)

            # drop proxy options if they exist, otherwise
            # head node won't be able to connect to workers
            remote_config["auth"].pop("ssh_proxy_command", None)

            if "ssh_private_key" in config["auth"]:
                remote_key_path = "~/ray_bootstrap_key.pem"
                remote_config["auth"]["ssh_private_key"] = remote_key_path

            # Adjust for new file locations
            new_mounts = {}
            for remote_path in config["file_mounts"]:
                new_mounts[remote_path] = remote_path
            remote_config["file_mounts"] = new_mounts
            remote_config["no_restart"] = no_restart

            remote_config = provider.prepare_for_head_node(remote_config)

            # Now inject the rewritten config and SSH key into the head node
            remote_config_file = tempfile.NamedTemporaryFile(
                "w", prefix="ray-bootstrap-")
            remote_config_file.write(json.dumps(remote_config))
            remote_config_file.flush()
            config["file_mounts"].update(
                {"~/ray_bootstrap_config.yaml": remote_config_file.name})

            if "ssh_private_key" in config["auth"]:
                config["file_mounts"].update({
                    remote_key_path:
                    config["auth"]["ssh_private_key"],
                })
            cli_logger.print("Prepared bootstrap config")

            if restart_only:
                setup_commands = []
                ray_start_commands = config["head_start_ray_commands"]
            elif no_restart:
                setup_commands = config["head_setup_commands"]
                ray_start_commands = []
            else:
                setup_commands = config["head_setup_commands"]
                ray_start_commands = config["head_start_ray_commands"]

            if not no_restart:
                warn_about_bad_start_command(ray_start_commands)

            updater = NodeUpdaterThread(
                node_id=head_node,
                provider_config=config["provider"],
                provider=provider,
                auth_config=config["auth"],
                cluster_name=config["cluster_name"],
                file_mounts=config["file_mounts"],
                initialization_commands=config["initialization_commands"],
                setup_commands=setup_commands,
                ray_start_commands=ray_start_commands,
                process_runner=_runner,
                runtime_hash=runtime_hash,
                file_mounts_contents_hash=file_mounts_contents_hash,
                is_head_node=True,
                docker_config=config.get("docker"))
            updater.start()
            updater.join()

            # Refresh the node cache so we see the external ip if available
            provider.non_terminated_nodes(head_node_tags)

            if config.get("provider", {}).get("use_internal_ips",
                                              False) is True:
                head_node_ip = provider.internal_ip(head_node)
            else:
                head_node_ip = provider.external_ip(head_node)

            if updater.exitcode != 0:
                # todo: this does not follow the mockup and is not good enough
                cli_logger.abort("Failed to setup head node.")

                cli_logger.old_error(
                    logger, "get_or_create_head_node: "
                    "Updating {} failed", head_node_ip)
                sys.exit(1)

            cli_logger.old_info(
                logger, "get_or_create_head_node: "
                "Head node up-to-date, IP address is: {}", head_node_ip)

        monitor_str = "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
        if override_cluster_name:
            modifiers = " --cluster-name={}".format(
                quote(override_cluster_name))
        else:
            modifiers = ""

        if cli_logger.old_style:
            print("To monitor autoscaling activity, you can run:\n\n"
                  "  ray exec {} {}{}\n".format(config_file,
                                                quote(monitor_str), modifiers))
            print("To open a console on the cluster:\n\n"
                  "  ray attach {}{}\n".format(config_file, modifiers))

            print("To get a remote shell to the cluster manually, run:\n\n"
                  "  {}\n".format(
                      updater.cmd_runner.remote_shell_command_str()))

        cli_logger.newline()
        with cli_logger.group("Useful commands"):
            cli_logger.print("Monitor autoscaling with")
            cli_logger.print(cf.bold("  ray exec {}{} {}"), config_file,
                             modifiers, quote(monitor_str))

            cli_logger.print("Connect to a terminal on the cluster head:")
            cli_logger.print(cf.bold("  ray attach {}{}"), config_file,
                             modifiers)

            remote_shell_str = updater.cmd_runner.remote_shell_command_str()
            cli_logger.print("Get a remote shell to the cluster manually:")
            cli_logger.print("  {}", remote_shell_str.strip())
    finally:
        provider.cleanup()
Example #30
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
                     override_cluster_name: Optional[str],
                     keep_min_workers: bool):
    """Destroys all nodes of a Ray cluster described by a config json."""
    config = yaml.safe_load(open(config_file).read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = prepare_config(config)
    validate_config(config)

    cli_logger.confirm(yes, "Destroying cluster.", _abort=True)
    cli_logger.old_confirm("This will destroy your cluster", yes)

    if not workers_only:
        try:
            exec_cluster(config_file,
                         cmd="ray stop",
                         run_env="auto",
                         screen=False,
                         tmux=False,
                         stop=False,
                         start=False,
                         override_cluster_name=override_cluster_name,
                         port_forward=None,
                         with_output=False)
        except Exception as e:
            # todo: add better exception info
            cli_logger.verbose_error("{}", str(e))
            cli_logger.warning(
                "Exception occured when stopping the cluster Ray runtime "
                "(use -v to dump teardown exceptions).")
            cli_logger.warning(
                "Ignoring the exception and "
                "attempting to shut down the cluster nodes anyway.")

            cli_logger.old_exception(
                logger, "Ignoring error attempting a clean shutdown.")

    provider = _get_node_provider(config["provider"], config["cluster_name"])
    try:

        def remaining_nodes():
            workers = provider.non_terminated_nodes(
                {TAG_RAY_NODE_KIND: NODE_KIND_WORKER})

            if keep_min_workers:
                min_workers = config.get("min_workers", 0)

                cli_logger.print(
                    "{} random worker nodes will not be shut down. " +
                    cf.dimmed("(due to {})"), cf.bold(min_workers),
                    cf.bold("--keep-min-workers"))
                cli_logger.old_info(logger,
                                    "teardown_cluster: Keeping {} nodes...",
                                    min_workers)

                workers = random.sample(workers, len(workers) - min_workers)

            # todo: it's weird to kill the head node but not all workers
            if workers_only:
                cli_logger.print(
                    "The head node will not be shut down. " +
                    cf.dimmed("(due to {})"), cf.bold("--workers-only"))

                return workers

            head = provider.non_terminated_nodes(
                {TAG_RAY_NODE_KIND: NODE_KIND_HEAD})

            return head + workers

        def run_docker_stop(node, container_name):
            try:
                updater = NodeUpdaterThread(
                    node_id=node,
                    provider_config=config["provider"],
                    provider=provider,
                    auth_config=config["auth"],
                    cluster_name=config["cluster_name"],
                    file_mounts=config["file_mounts"],
                    initialization_commands=[],
                    setup_commands=[],
                    ray_start_commands=[],
                    runtime_hash="",
                    file_mounts_contents_hash="",
                    is_head_node=False,
                    docker_config=config.get("docker"))
                _exec(updater,
                      f"docker stop {container_name}",
                      False,
                      False,
                      run_env="host")
            except Exception:
                cli_logger.warning(f"Docker stop failed on {node}")
                cli_logger.old_warning(logger, f"Docker stop failed on {node}")

        # Loop here to check that both the head and worker nodes are
        # actually gone.
        A = remaining_nodes()

        container_name = config.get("docker", {}).get("container_name")
        if container_name:
            for node in A:
                run_docker_stop(node, container_name)

        with LogTimer("teardown_cluster: done."):
            while A:
                cli_logger.old_info(
                    logger, "teardown_cluster: "
                    "Shutting down {} nodes...", len(A))

                provider.terminate_nodes(A)

                cli_logger.print("Requested {} nodes to shut down.",
                                 cf.bold(len(A)),
                                 _tags=dict(interval="1s"))

                time.sleep(
                    POLL_INTERVAL)  # todo: interval should be a variable
                A = remaining_nodes()
                cli_logger.print("{} nodes remaining after {} second(s).",
                                 cf.bold(len(A)), POLL_INTERVAL)
            cli_logger.success("No nodes remaining.")
    finally:
        provider.cleanup()