Example #1
0
def create_or_update_cluster(config_file: str,
                             override_min_workers: Optional[int],
                             override_max_workers: Optional[int],
                             no_restart: bool,
                             restart_only: bool,
                             yes: bool,
                             override_cluster_name: Optional[str],
                             no_config_cache: bool = False,
                             redirect_command_output: bool = False,
                             use_login_shells: bool = True) -> None:
    """Create or update an autoscaling Ray cluster from a YAML config file.

    Loads the cluster config, checks that its provider type is known,
    applies any command-line overrides, bootstraps the config and then
    hands off to `get_or_create_head_node`.

    Arguments:
        config_file: path to the cluster yaml
        override_min_workers: if not None, replaces `min_workers` in the
            loaded config
        override_max_workers: if not None, replaces `max_workers` in the
            loaded config
        no_restart: forwarded unchanged to `get_or_create_head_node`
        restart_only: forwarded unchanged to `get_or_create_head_node`
        yes: forwarded unchanged to `get_or_create_head_node`
        override_cluster_name: if not None, replaces `cluster_name` in the
            loaded config
        no_config_cache: forwarded to `_bootstrap_config`
        redirect_command_output: forwarded to
            `cmd_output_util.set_output_redirected`; None is treated as
            False
        use_login_shells: forwarded to `set_using_login_shells`; when
            False, interactive command output is also disabled

    Raises:
        FileNotFoundError: if `config_file` does not exist (only reached
            if `cli_logger.abort` returns instead of raising)
        yaml.parser.ParserError, yaml.scanner.ScannerError: if the config
            is not valid YAML (same caveat as above)
        NotImplementedError: if the provider type is unknown (same caveat)
    """
    set_using_login_shells(use_login_shells)
    if not use_login_shells:
        cmd_output_util.set_allow_interactive(False)
    if redirect_command_output is None:
        # Do not redirect by default.
        cmd_output_util.set_output_redirected(False)
    else:
        cmd_output_util.set_output_redirected(redirect_command_output)

    if use_login_shells:
        cli_logger.warning(
            "Commands running under a login shell can produce more "
            "output than special processing can handle.")
        cli_logger.warning(
            "Thus, the output from subcommands will be logged as is.")
        cli_logger.warning(
            "Consider using {}, {}.", cf.bold("--use-normal-shells"),
            cf.underlined("if you tested your workflow and it is compatible"))
        cli_logger.newline()

    def handle_yaml_error(e):
        # Report a YAML syntax problem and abort.
        cli_logger.error("Cluster config invalid")
        cli_logger.newline()
        cli_logger.error("Failed to load YAML file " + cf.bold("{}"),
                         config_file)
        cli_logger.newline()
        with cli_logger.verbatim_error_ctx("PyYAML error:"):
            cli_logger.error(e)
        cli_logger.abort()

    try:
        # Use a context manager so the file handle is closed even when
        # parsing fails.
        with open(config_file) as f:
            config = yaml.safe_load(f.read())
    except FileNotFoundError:
        cli_logger.abort(
            "Provided cluster configuration file ({}) does not exist",
            cf.bold(config_file))
        # Fallback in case abort() returns instead of raising.
        raise
    except (yaml.parser.ParserError, yaml.scanner.ScannerError) as e:
        handle_yaml_error(e)
        raise

    # todo: validate file_mounts, ssh keys, etc.

    importer = NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        cli_logger.abort(
            "Unknown provider type " + cf.bold("{}") + "\n"
            "Available providers are: {}", config["provider"]["type"],
            cli_logger.render_list([
                k for k in NODE_PROVIDERS.keys()
                if NODE_PROVIDERS[k] is not None
            ]))
        raise NotImplementedError("Unsupported provider {}".format(
            config["provider"]))

    cli_logger.success("Cluster configuration valid")

    printed_overrides = False

    def handle_cli_override(key, override):
        # Apply a command-line override; warn only when it replaces a value
        # that was present in the config file.
        if override is not None:
            if key in config:
                nonlocal printed_overrides
                printed_overrides = True
                cli_logger.warning(
                    "`{}` override provided on the command line.\n"
                    "  Using " + cf.bold("{}") +
                    cf.dimmed(" [configuration file has " + cf.bold("{}") +
                              "]"), key, override, config[key])
            config[key] = override

    handle_cli_override("min_workers", override_min_workers)
    handle_cli_override("max_workers", override_max_workers)
    handle_cli_override("cluster_name", override_cluster_name)

    if printed_overrides:
        cli_logger.newline()

    cli_logger.labeled_value("Cluster", config["cluster_name"])

    # disable the cli_logger here if needed
    # because it only supports aws
    if config["provider"]["type"] != "aws":
        cli_logger.old_style = True
    cli_logger.newline()
    config = _bootstrap_config(config, no_config_cache=no_config_cache)

    try_logging_config(config)
    get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
                            override_cluster_name)
Example #2
0
def exec_cluster(config_file: str,
                 *,
                 cmd: Any = None,
                 run_env: str = "auto",
                 screen: bool = False,
                 tmux: bool = False,
                 stop: bool = False,
                 start: bool = False,
                 override_cluster_name: Optional[str] = None,
                 no_config_cache: bool = False,
                 port_forward: Any = None,
                 with_output: bool = False):
    """Runs a command on the specified cluster.

    Arguments:
        config_file: path to the cluster yaml
        cmd: command to run
        run_env: whether to run the command on the host or in a container.
            Select between "auto", "host" and "docker"
        screen: whether to run in a screen
        tmux: whether to run in a tmux session
        stop: whether to stop the cluster after command run
        start: whether to start the cluster if it isn't up
        override_cluster_name: set the name of the cluster
        no_config_cache: forwarded to `_bootstrap_config`
        port_forward (int or list[int]): port(s) to forward
        with_output: forwarded to `_exec`

    Returns:
        Whatever `_exec` returns for the command.
    """
    assert not (screen and tmux), "Can specify only one of `screen` or `tmux`."
    assert run_env in RUN_ENV_TYPES, "--run_env must be in {}".format(
        RUN_ENV_TYPES)
    # TODO(rliaw): We default this to True to maintain backwards-compat.
    # In the future we would want to support disabling login-shells
    # and interactivity.
    cmd_output_util.set_allow_interactive(True)

    # Use a context manager so the config file handle is always closed.
    with open(config_file) as f:
        config = yaml.safe_load(f.read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = _bootstrap_config(config, no_config_cache=no_config_cache)

    head_node = _get_head_node(config,
                               config_file,
                               override_cluster_name,
                               create_if_needed=start)

    provider = get_node_provider(config["provider"], config["cluster_name"])
    try:
        updater = NodeUpdaterThread(node_id=head_node,
                                    provider_config=config["provider"],
                                    provider=provider,
                                    auth_config=config["auth"],
                                    cluster_name=config["cluster_name"],
                                    file_mounts=config["file_mounts"],
                                    initialization_commands=[],
                                    setup_commands=[],
                                    ray_start_commands=[],
                                    runtime_hash="",
                                    file_mounts_contents_hash="",
                                    is_head_node=True,
                                    docker_config=config.get("docker"))

        is_docker = isinstance(updater.cmd_runner, DockerCommandRunner)

        if cmd and stop:
            # Join the user command with the shutdown commands so each is
            # separated by "; ". Appending the joined shutdown commands
            # directly to `cmd` would fuse them ("<cmd>ray stop; ...").
            cmd = "; ".join([
                cmd, "ray stop",
                "ray teardown ~/ray_bootstrap_config.yaml --yes --workers-only"
            ])
            if is_docker and run_env == "docker":
                updater.cmd_runner.shutdown_after_next_cmd()
            else:
                cmd += "; sudo shutdown -h now"

        result = _exec(updater,
                       cmd,
                       screen,
                       tmux,
                       port_forward=port_forward,
                       with_output=with_output,
                       run_env=run_env)
        if tmux or screen:
            # Tell the user how to re-attach to the detached session.
            attach_command_parts = ["ray attach", config_file]
            if override_cluster_name is not None:
                attach_command_parts.append(
                    "--cluster-name={}".format(override_cluster_name))
            if tmux:
                attach_command_parts.append("--tmux")
            elif screen:
                attach_command_parts.append("--screen")

            attach_command = " ".join(attach_command_parts)
            cli_logger.print("Run `{}` to check command status.",
                             cf.bold(attach_command))

            attach_info = "Use `{}` to check on command status.".format(
                attach_command)
            cli_logger.old_info(logger, attach_info)
        return result
    finally:
        # Always release provider resources, even if the command failed.
        provider.cleanup()