Example #1
def submit(
    address: Optional[str],
    job_id: Optional[str],
    runtime_env: Optional[str],
    runtime_env_json: Optional[str],
    working_dir: Optional[str],
    entrypoint: Tuple[str],
    no_wait: bool,
):
    """Submits a job to be run on the cluster docstring

    Example:
        >>> ray job submit -- python my_script.py --arg=val
    """
    client = _get_sdk_client(address, create_cluster_if_needed=True)

    final_runtime_env = {}
    if runtime_env is not None:
        if runtime_env_json is not None:
            raise ValueError(
                "Only one of --runtime_env and " "--runtime-env-json can be provided."
            )
        with open(runtime_env, "r") as f:
            final_runtime_env = yaml.safe_load(f)

    elif runtime_env_json is not None:
        final_runtime_env = json.loads(runtime_env_json)

    if working_dir is not None:
        if "working_dir" in final_runtime_env:
            cli_logger.warning(
                "Overriding runtime_env working_dir with --working-dir option"
            )

        final_runtime_env["working_dir"] = working_dir

    job_id = client.submit_job(
        entrypoint=list2cmdline(entrypoint),
        job_id=job_id,
        runtime_env=final_runtime_env,
    )

    _log_big_success_msg(f"Job '{job_id}' submitted successfully")

    with cli_logger.group("Next steps"):
        cli_logger.print("Query the logs of the job:")
        with cli_logger.indented():
            cli_logger.print(cf.bold(f"ray job logs {job_id}"))

        cli_logger.print("Query the status of the job:")
        with cli_logger.indented():
            cli_logger.print(cf.bold(f"ray job status {job_id}"))

        cli_logger.print("Request the job to be stopped:")
        with cli_logger.indented():
            cli_logger.print(cf.bold(f"ray job stop {job_id}"))

    cli_logger.newline()
    sdk_version = client.get_version()
    # sdk version 0 does not have log streaming
    if not no_wait:
        if int(sdk_version) > 0:
            cli_logger.print(
                "Tailing logs until the job exits " "(disable with --no-wait):"
            )
            asyncio.get_event_loop().run_until_complete(_tail_logs(client, job_id))
        else:
            cli_logger.warning(
                "Tailing logs is not enabled for job sdk client version "
                f"{sdk_version}. Please upgrade your ray to latest version "
                "for this feature."
            )
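
A minimal sketch of the precedence implemented above, using hypothetical names (not part of the Ray CLI): the YAML file and the inline JSON are mutually exclusive, and --working-dir overrides a working_dir supplied by either source.

import json
from typing import Optional


def resolve_runtime_env(env_from_yaml: Optional[dict],
                        env_json: Optional[str],
                        working_dir: Optional[str]) -> dict:
    # Hypothetical helper, not Ray API: mirrors the logic of submit() above.
    if env_from_yaml is not None and env_json is not None:
        raise ValueError("Only one runtime_env source may be provided.")
    final = dict(env_from_yaml or {})
    if env_json is not None:
        final = json.loads(env_json)
    if working_dir is not None:
        final["working_dir"] = working_dir
    return final
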
Example #2
    def create_node(self, node_config, tags, count) -> Dict[str, Any]:
        """Creates instances.

        Returns dict mapping instance id to ec2.Instance object for the created
        instances.
        """
        # sort tags by key to support deterministic unit test stubbing
        tags = OrderedDict(sorted(copy.deepcopy(tags).items()))

        reused_nodes_dict = {}
        # Try to reuse previously stopped nodes with compatible configs
        if self.cache_stopped_nodes:
            # TODO(ekl) this is breaking the abstraction boundary a little by
            # peeking into the tag set.
            filters = [
                {
                    "Name": "instance-state-name",
                    "Values": ["stopped", "stopping"],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
                    "Values": [self.cluster_name],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_NODE_KIND),
                    "Values": [tags[TAG_RAY_NODE_KIND]],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG),
                    "Values": [tags[TAG_RAY_LAUNCH_CONFIG]],
                },
            ]
            # This tag may not always be present.
            if TAG_RAY_USER_NODE_TYPE in tags:
                filters.append({
                    "Name": "tag:{}".format(TAG_RAY_USER_NODE_TYPE),
                    "Values": [tags[TAG_RAY_USER_NODE_TYPE]],
                })

            reuse_nodes = list(
                self.ec2.instances.filter(Filters=filters))[:count]
            reuse_node_ids = [n.id for n in reuse_nodes]
            reused_nodes_dict = {n.id: n for n in reuse_nodes}
            if reuse_nodes:
                cli_logger.print(
                    # todo: handle plural vs singular?
                    "Reusing nodes {}. "
                    "To disable reuse, set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration.",
                    cli_logger.render_list(reuse_node_ids))

                # todo: timed?
                with cli_logger.group("Stopping instances to reuse"):
                    for node in reuse_nodes:
                        self.tag_cache[node.id] = from_aws_format(
                            {x["Key"]: x["Value"]
                             for x in node.tags})
                        if node.state["Name"] == "stopping":
                            cli_logger.print("Waiting for instance {} to stop",
                                             node.id)
                            node.wait_until_stopped()

                self.ec2.meta.client.start_instances(
                    InstanceIds=reuse_node_ids)
                for node_id in reuse_node_ids:
                    self.set_node_tags(node_id, tags)
                count -= len(reuse_node_ids)

        created_nodes_dict = {}
        if count:
            created_nodes_dict = self._create_node(node_config, tags, count)

        all_created_nodes = reused_nodes_dict
        all_created_nodes.update(created_nodes_dict)
        return all_created_nodes
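
The "deterministic unit test stubbing" comment above hinges on ordering: sorting the tag items before building filters keeps the EC2 request arguments stable across runs. A minimal sketch with illustrative tag keys (the real code uses the TAG_RAY_* constants):

import copy
from collections import OrderedDict

tags = {"ray-node-kind": "worker", "ray-cluster-name": "demo"}
ordered = OrderedDict(sorted(copy.deepcopy(tags).items()))
assert list(ordered) == ["ray-cluster-name", "ray-node-kind"]
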
Example #3
    def run(
            self,
            cmd,
            timeout=120,
            exit_on_fail=False,
            port_forward=None,
            with_output=False,
            environment_variables: Optional[Dict[str, object]] = None,
            run_env="auto",  # Unused argument.
            ssh_options_override_ssh_key="",
            shutdown_after_run=False,
            silent=False):
        if shutdown_after_run:
            cmd += "; sudo shutdown -h now"
        if ssh_options_override_ssh_key:
            ssh_options = SSHOptions(ssh_options_override_ssh_key)
        else:
            ssh_options = self.ssh_options

        assert isinstance(
            ssh_options, SSHOptions
        ), "ssh_options must be of type SSHOptions, got {}".format(
            type(ssh_options))

        self._set_ssh_ip_if_required()

        if is_using_login_shells():
            ssh = ["ssh", "-tt"]
        else:
            ssh = ["ssh"]

        if port_forward:
            with cli_logger.group("Forwarding ports"):
                if not isinstance(port_forward, list):
                    port_forward = [port_forward]
                for local, remote in port_forward:
                    cli_logger.verbose(
                        "Forwarding port {} to port {} on localhost.",
                        cf.bold(local), cf.bold(remote))  # todo: msg
                    ssh += ["-L", "{}:localhost:{}".format(remote, local)]

        final_cmd = ssh + ssh_options.to_ssh_options_list(
            timeout=timeout) + ["{}@{}".format(self.ssh_user, self.ssh_ip)]
        if cmd:
            if environment_variables:
                cmd = _with_environment_variables(cmd, environment_variables)
            if is_using_login_shells():
                final_cmd += _with_interactive(cmd)
            else:
                final_cmd += [cmd]
        else:
            # We do this because `-o ControlMaster` causes the `-N` flag to
            # still create an interactive shell in some ssh versions.
            final_cmd.append("while true; do sleep 86400; done")

        cli_logger.verbose("Running `{}`", cf.bold(cmd))
        with cli_logger.indented():
            cli_logger.very_verbose("Full command is `{}`",
                                    cf.bold(" ".join(final_cmd)))

        if cli_logger.verbosity > 0:
            with cli_logger.indented():
                return self._run_helper(final_cmd,
                                        with_output,
                                        exit_on_fail,
                                        silent=silent)
        else:
            return self._run_helper(final_cmd,
                                    with_output,
                                    exit_on_fail,
                                    silent=silent)
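
For reference, this is how the port-forward loop above expands a couple of (local, remote) pairs into ssh arguments; the port numbers are illustrative only.

port_forward = [(8265, 8265), (6379, 16379)]  # illustrative values
ssh = ["ssh", "-tt"]
for local, remote in port_forward:
    ssh += ["-L", "{}:localhost:{}".format(remote, local)]
# ssh == ["ssh", "-tt",
#         "-L", "8265:localhost:8265",
#         "-L", "16379:localhost:6379"]
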
Example #4
    def _create_node(self, node_config, tags, count):
        created_nodes_dict = {}

        tags = to_aws_format(tags)
        conf = node_config.copy()

        tag_pairs = [{
            "Key": TAG_RAY_CLUSTER_NAME,
            "Value": self.cluster_name,
        }]
        for k, v in tags.items():
            tag_pairs.append({
                "Key": k,
                "Value": v,
            })
        tag_specs = [{
            "ResourceType": "instance",
            "Tags": tag_pairs,
        }]
        user_tag_specs = conf.get("TagSpecifications", [])
        # Allow users to add tags and override values of existing
        # tags with their own. This only applies to the resource type
        # "instance". All other resource types are appended to the list of
        # tag specs.
        for user_tag_spec in user_tag_specs:
            if user_tag_spec["ResourceType"] == "instance":
                for user_tag in user_tag_spec["Tags"]:
                    exists = False
                    for tag in tag_specs[0]["Tags"]:
                        if user_tag["Key"] == tag["Key"]:
                            exists = True
                            tag["Value"] = user_tag["Value"]
                            break
                    if not exists:
                        tag_specs[0]["Tags"] += [user_tag]
            else:
                tag_specs += [user_tag_spec]

        # SubnetIds is not a real config key: we must resolve to a
        # single SubnetId before invoking the AWS API.
        subnet_ids = conf.pop("SubnetIds")

        for attempt in range(1, BOTO_CREATE_MAX_RETRIES + 1):
            try:
                subnet_id = subnet_ids[self.subnet_idx % len(subnet_ids)]

                self.subnet_idx += 1
                conf.update({
                    "MinCount": 1,
                    "MaxCount": count,
                    "SubnetId": subnet_id,
                    "TagSpecifications": tag_specs
                })
                created = self.ec2_fail_fast.create_instances(**conf)
                created_nodes_dict = {n.id: n for n in created}

                # todo: timed?
                # todo: handle plurality?
                with cli_logger.group(
                        "Launched {} nodes",
                        count,
                        _tags=dict(subnet_id=subnet_id)):
                    for instance in created:
                        # NOTE(maximsmol): This is needed for mocking
                        # boto3 for tests. This is likely a bug in moto
                        # but AWS docs don't seem to say.
                        # You can patch moto/ec2/responses/instances.py
                        # to fix this (add <stateReason> to EC2_RUN_INSTANCES)

                        # The correct value is technically
                        # {"code": "0", "Message": "pending"}
                        state_reason = instance.state_reason or {
                            "Message": "pending"
                        }

                        cli_logger.print(
                            "Launched instance {}",
                            instance.instance_id,
                            _tags=dict(
                                state=instance.state["Name"],
                                info=state_reason["Message"]))
                break
            except botocore.exceptions.ClientError as exc:
                if attempt == BOTO_CREATE_MAX_RETRIES:
                    cli_logger.abort(
                        "Failed to launch instances. Max attempts exceeded.",
                        exc=exc,
                    )
                else:
                    cli_logger.warning(
                        "create_instances: Attempt failed with {}, retrying.",
                        exc)
        return created_nodes_dict
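
The override rule for the "instance" resource type above can be written more compactly; a hedged sketch (hypothetical helper with the same effect as the nested loops): user tags replace defaults that share a Key, and new keys are appended.

def merge_instance_tags(default_tags, user_tags):
    merged = {t["Key"]: t["Value"] for t in default_tags}
    merged.update({t["Key"]: t["Value"] for t in user_tags})
    return [{"Key": k, "Value": v} for k, v in merged.items()]


defaults = [{"Key": "ray-cluster-name", "Value": "demo"}]
user = [{"Key": "ray-cluster-name", "Value": "mine"}, {"Key": "team", "Value": "ml"}]
print(merge_instance_tags(defaults, user))
# [{'Key': 'ray-cluster-name', 'Value': 'mine'}, {'Key': 'team', 'Value': 'ml'}]
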
Example #5
    def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str],
                    count: int) -> Optional[Dict[str, Any]]:
        filter_tags = [
            {
                "Key": TAG_RAY_CLUSTER_NAME,
                "Value": self.cluster_name,
            },
            {
                "Key": TAG_RAY_NODE_KIND,
                "Value": tags[TAG_RAY_NODE_KIND]
            },
            {
                "Key": TAG_RAY_USER_NODE_TYPE,
                "Value": tags[TAG_RAY_USER_NODE_TYPE]
            },
            {
                "Key": TAG_RAY_LAUNCH_CONFIG,
                "Value": tags[TAG_RAY_LAUNCH_CONFIG]
            },
            {
                "Key": TAG_RAY_NODE_NAME,
                "Value": tags[TAG_RAY_NODE_NAME]
            },
        ]

        reused_nodes_dict = {}
        if self.cache_stopped_nodes:
            reuse_nodes_candidate = self.acs.describe_instances(
                tags=filter_tags)
            if reuse_nodes_candidate:
                with cli_logger.group("Stopping instances to reuse"):
                    reuse_node_ids = []
                    for node in reuse_nodes_candidate:
                        node_id = node.get("InstanceId")
                        status = node.get("Status")
                        if status != STOPPING and status != STOPPED:
                            continue
                        if status == STOPPING:
                            # wait for the node to stop
                            while (self.acs.describe_instances(
                                    instance_ids=[node_id])[0].get("Status") ==
                                   STOPPING):
                                logging.info(
                                    "Waiting for %s to stop", node_id)
                                time.sleep(STOPPING_NODE_DELAY)
                        # logger.info("reuse %s" % node_id)
                        reuse_node_ids.append(node_id)
                        reused_nodes_dict[node.get("InstanceId")] = node
                        self.acs.start_instance(node_id)
                        self.tag_cache[node_id] = node.get("Tags")
                        self.set_node_tags(node_id, tags)
                        if len(reuse_node_ids) == count:
                            break
                count -= len(reuse_node_ids)

        created_nodes_dict = {}
        if count > 0:
            filter_tags.append({
                "Key": TAG_RAY_NODE_STATUS,
                "Value": tags[TAG_RAY_NODE_STATUS]
            })
            instance_id_sets = self.acs.run_instances(
                instance_type=node_config["InstanceType"],
                image_id=node_config["ImageId"],
                tags=filter_tags,
                amount=count,
                vswitch_id=self.provider_config["v_switch_id"],
                security_group_id=self.provider_config["security_group_id"],
                key_pair_name=self.provider_config["key_name"],
            )
            instances = self.acs.describe_instances(
                instance_ids=instance_id_sets)

            if instances is not None:
                for instance in instances:
                    created_nodes_dict[instance.get("InstanceId")] = instance

        all_created_nodes = reused_nodes_dict
        all_created_nodes.update(created_nodes_dict)
        return all_created_nodes
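
The inline polling loop above can be factored into a small helper; a hedged sketch that assumes the same describe_instances shape and the STOPPING / STOPPING_NODE_DELAY constants used in the snippet.

import time


def wait_until_stopped(describe_instances, node_id):
    # Poll until the instance is no longer reported as stopping.
    while (describe_instances(instance_ids=[node_id])[0].get("Status")
           == STOPPING):
        time.sleep(STOPPING_NODE_DELAY)
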
Example #6
def log_to_cli(config):
    provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(provider_name is not None,
                        "Could not find a pretty name for the AWS provider.")

    with cli_logger.group("{} config", provider_name):

        def same_everywhere(key):
            return config["head_node"][key] == config["worker_nodes"][key]

        def print_info(resource_string,
                       key,
                       head_src_key,
                       workers_src_key,
                       allowed_tags=["default"],
                       list_value=False):

            head_tags = {}
            workers_tags = {}

            if _log_info[head_src_key] in allowed_tags:
                head_tags[_log_info[head_src_key]] = True
            if _log_info[workers_src_key] in allowed_tags:
                workers_tags[_log_info[workers_src_key]] = True

            head_value_str = config["head_node"][key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if same_everywhere(key):
                cli_logger.labeled_value(  # todo: handle plural vs singular?
                    resource_string + " (head & workers)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
            else:
                workers_value_str = config["worker_nodes"][key]
                if list_value:
                    workers_value_str = cli_logger.render_list(
                        workers_value_str)

                cli_logger.labeled_value(resource_string + " (head)",
                                         "{}",
                                         head_value_str,
                                         _tags=head_tags)
                cli_logger.labeled_value(resource_string + " (workers)",
                                         "{}",
                                         workers_value_str,
                                         _tags=workers_tags)

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        cli_logger.labeled_value(
            "IAM Profile",
            "{}",
            _arn_to_name(config["head_node"]["IamInstanceProfile"]["Arn"]),
            _tags=tags)

        if ("KeyName" in config["head_node"]
                and "KeyName" in config["worker_nodes"]):
            print_info("EC2 Key pair", "KeyName", "keypair_src", "keypair_src")

        print_info("VPC Subnets",
                   "SubnetIds",
                   "head_subnet_src",
                   "workers_subnet_src",
                   list_value=True)
        print_info("EC2 Security groups",
                   "SecurityGroupIds",
                   "head_security_group_src",
                   "workers_security_group_src",
                   list_value=True)
        print_info("EC2 AMI",
                   "ImageId",
                   "head_ami_src",
                   "workers_ami_src",
                   allowed_tags=["dlami"])

    cli_logger.newline()
Example #7
    cf.bold("Bold ") + cf.italic("Italic ") + cf.underlined("Underlined"))
cli_logger.labeled_value("Label", "value")
cli_logger.print("List: {}", cli_logger.render_list([1, 2, 3]))
cli_logger.newline()
cli_logger.very_verbose("Very verbose")
cli_logger.verbose("Verbose")
cli_logger.verbose_warning("Verbose warning")
cli_logger.verbose_error("Verbose error")
cli_logger.print("Info")
cli_logger.success("Success")
cli_logger.warning("Warning")
cli_logger.error("Error")
cli_logger.newline()
try:
    cli_logger.abort("Abort")
except Exception:
    pass
try:
    cli_logger.doassert(False, "Assert")
except Exception:
    pass
cli_logger.newline()
cli_logger.confirm(True, "example")
cli_logger.newline()
with cli_logger.indented():
    cli_logger.print("Indented")
with cli_logger.group("Group"):
    cli_logger.print("Group contents")
with cli_logger.verbatim_error_ctx("Verbatim error"):
    cli_logger.print("Error contents")
Example #8
def start(node_ip_address, address, port, redis_password, redis_shard_ports,
          object_manager_port, node_manager_port, gcs_server_port,
          min_worker_port, max_worker_port, worker_port_list, memory,
          object_store_memory, redis_max_memory, num_cpus, num_gpus, resources,
          head, include_dashboard, dashboard_host, dashboard_port, block,
          plasma_directory, autoscaling_config, no_redirect_worker_output,
          no_redirect_output, plasma_store_socket_name, raylet_socket_name,
          temp_dir, java_worker_options, load_code_from_local,
          code_search_path, system_config, lru_evict,
          enable_object_reconstruction, metrics_export_port, log_style,
          log_color, verbose):
    """Start Ray processes manually on the local machine."""
    cli_logger.configure(log_style, log_color, verbose)
    if gcs_server_port and not head:
        raise ValueError(
            "gcs_server_port can be only assigned when you specify --head.")

    # Convert hostnames to numerical IP address.
    if node_ip_address is not None:
        node_ip_address = services.address_to_ip(node_ip_address)

    redis_address = None
    if address is not None:
        (redis_address, redis_address_ip,
         redis_address_port) = services.validate_redis_address(address)
    try:
        resources = json.loads(resources)
    except Exception:
        cli_logger.error("`{}` is not a valid JSON string.",
                         cf.bold("--resources"))
        cli_logger.abort(
            "Valid values look like this: `{}`",
            cf.bold("--resources='\"CustomResource3\": 1, "
                    "\"CustomResource2\": 2}'"))

        raise Exception("Unable to parse the --resources argument using "
                        "json.loads. Try using a format like\n\n"
                        "    --resources='{\"CustomResource1\": 3, "
                        "\"CustomReseource2\": 2}'")

    redirect_worker_output = None if not no_redirect_worker_output else True
    redirect_output = None if not no_redirect_output else True
    ray_params = ray.parameter.RayParams(
        node_ip_address=node_ip_address,
        min_worker_port=min_worker_port,
        max_worker_port=max_worker_port,
        worker_port_list=worker_port_list,
        object_manager_port=object_manager_port,
        node_manager_port=node_manager_port,
        gcs_server_port=gcs_server_port,
        memory=memory,
        object_store_memory=object_store_memory,
        redis_password=redis_password,
        redirect_worker_output=redirect_worker_output,
        redirect_output=redirect_output,
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        resources=resources,
        plasma_directory=plasma_directory,
        huge_pages=False,
        plasma_store_socket_name=plasma_store_socket_name,
        raylet_socket_name=raylet_socket_name,
        temp_dir=temp_dir,
        include_dashboard=include_dashboard,
        dashboard_host=dashboard_host,
        dashboard_port=dashboard_port,
        java_worker_options=java_worker_options,
        load_code_from_local=load_code_from_local,
        code_search_path=code_search_path,
        _system_config=system_config,
        lru_evict=lru_evict,
        enable_object_reconstruction=enable_object_reconstruction,
        metrics_export_port=metrics_export_port)
    if head:
        # Use the default if port is None; allocate an available port
        # if port is 0.
        if port is None:
            port = ray_constants.DEFAULT_PORT

        if port == 0:
            with socket() as s:
                s.bind(("", 0))
                port = s.getsockname()[1]

        num_redis_shards = None
        # Start Ray on the head node.
        if redis_shard_ports is not None:
            redis_shard_ports = redis_shard_ports.split(",")
            # Infer the number of Redis shards from the ports if the number is
            # not provided.
            num_redis_shards = len(redis_shard_ports)

        if redis_address is not None:
            cli_logger.abort(
                "`{}` starts a new Redis server, `{}` should not be set.",
                cf.bold("--head"), cf.bold("--address"))

            raise Exception("If --head is passed in, a Redis server will be "
                            "started, so a Redis address should not be "
                            "provided.")

        node_ip_address = services.get_node_ip_address()

        # Get the node IP address if one is not provided.
        ray_params.update_if_absent(node_ip_address=node_ip_address)
        cli_logger.labeled_value("Local node IP", ray_params.node_ip_address)
        ray_params.update_if_absent(
            redis_port=port,
            redis_shard_ports=redis_shard_ports,
            redis_max_memory=redis_max_memory,
            num_redis_shards=num_redis_shards,
            redis_max_clients=None,
            autoscaling_config=autoscaling_config,
        )

        # Fail early when starting a new cluster while one is already running.
        if address is None:
            default_address = f"{node_ip_address}:{port}"
            redis_addresses = services.find_redis_address(default_address)
            if len(redis_addresses) > 0:
                raise ConnectionError(
                    f"Ray is already running at {default_address}. "
                    f"Please specify a different port using the `--port`"
                    f" command to `ray start`.")

        node = ray.node.Node(
            ray_params, head=True, shutdown_at_exit=block, spawn_reaper=block)
        redis_address = node.redis_address

        # this is a noop if new-style is not set, so the old logger calls
        # are still in place
        cli_logger.newline()
        startup_msg = "Ray runtime started."
        cli_logger.success("-" * len(startup_msg))
        cli_logger.success(startup_msg)
        cli_logger.success("-" * len(startup_msg))
        cli_logger.newline()
        with cli_logger.group("Next steps"):
            cli_logger.print(
                "To connect to this Ray runtime from another node, run")
            cli_logger.print(
                cf.bold("  ray start --address='{}'{}"), redis_address,
                f" --redis-password='******'"
                if redis_password else "")
            cli_logger.newline()
            cli_logger.print("Alternatively, use the following Python code:")
            with cli_logger.indented():
                with cf.with_style("monokai") as c:
                    cli_logger.print("{} ray", c.magenta("import"))
                    cli_logger.print(
                        "ray{}init(address{}{}{})", c.magenta("."),
                        c.magenta("="), c.yellow("'auto'"),
                        ", _redis_password{}{}".format(
                            c.magenta("="),
                            c.yellow("'" + redis_password + "'"))
                        if redis_password else "")
            cli_logger.newline()
            cli_logger.print(
                cf.underlined("If connection fails, check your "
                              "firewall settings and "
                              "network configuration."))
            cli_logger.newline()
            cli_logger.print("To terminate the Ray runtime, run")
            cli_logger.print(cf.bold("  ray stop"))
    else:
        # Start Ray on a non-head node.
        if not (port is None):
            cli_logger.abort("`{}` should not be specified without `{}`.",
                             cf.bold("--port"), cf.bold("--head"))

            raise Exception("If --head is not passed in, --port is not "
                            "allowed.")
        if redis_shard_ports is not None:
            cli_logger.abort("`{}` should not be specified without `{}`.",
                             cf.bold("--redis-shard-ports"), cf.bold("--head"))

            raise Exception("If --head is not passed in, --redis-shard-ports "
                            "is not allowed.")
        if redis_address is None:
            cli_logger.abort("`{}` is required unless starting with `{}`.",
                             cf.bold("--address"), cf.bold("--head"))

            raise Exception("If --head is not passed in, --address must "
                            "be provided.")
        if include_dashboard:
            cli_logger.abort("`{}` should not be specified without `{}`.",
                             cf.bold("--include-dashboard"), cf.bold("--head"))

            raise ValueError(
                "If --head is not passed in, the --include-dashboard"
                "flag is not relevant.")

        # Wait for the Redis server to be started, and throw an exception
        # if we can't connect to it.
        services.wait_for_redis_to_start(
            redis_address_ip, redis_address_port, password=redis_password)

        # Create a Redis client.
        redis_client = services.create_redis_client(
            redis_address, password=redis_password)

        # Check that the version information on this node matches the version
        # information that the cluster was started with.
        services.check_version_info(redis_client)

        # Get the node IP address if one is not provided.
        ray_params.update_if_absent(
            node_ip_address=services.get_node_ip_address(redis_address))

        cli_logger.labeled_value("Local node IP", ray_params.node_ip_address)

        # Check that there aren't already Redis clients with the same IP
        # address connected with this Redis instance. This raises an exception
        # if the Redis server already has clients on this node.
        check_no_existing_redis_clients(ray_params.node_ip_address,
                                        redis_client)
        ray_params.update(redis_address=redis_address)
        node = ray.node.Node(
            ray_params, head=False, shutdown_at_exit=block, spawn_reaper=block)

        cli_logger.newline()
        startup_msg = "Ray runtime started."
        cli_logger.success("-" * len(startup_msg))
        cli_logger.success(startup_msg)
        cli_logger.success("-" * len(startup_msg))
        cli_logger.newline()
        cli_logger.print("To terminate the Ray runtime, run")
        cli_logger.print(cf.bold("  ray stop"))

    if block:
        cli_logger.newline()
        with cli_logger.group(cf.bold("--block")):
            cli_logger.print(
                "This command will now block until terminated by a signal.")
            cli_logger.print(
                "Runing subprocesses are monitored and a message will be "
                "printed if any of them terminate unexpectedly.")

        while True:
            time.sleep(1)
            deceased = node.dead_processes()
            if len(deceased) > 0:
                cli_logger.newline()
                cli_logger.error("Some Ray subprcesses exited unexpectedly:")

                with cli_logger.indented():
                    for process_type, process in deceased:
                        cli_logger.error(
                            "{}",
                            cf.bold(str(process_type)),
                            _tags={"exit code": str(process.returncode)})

                # shutdown_at_exit will handle cleanup.
                cli_logger.newline()
                cli_logger.error("Remaining processes will be killed.")
                sys.exit(1)
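
The --port=0 handling above relies on a standard trick: binding a socket to port 0 makes the OS pick a free ephemeral port, which can then be read back. A minimal standalone sketch:

from socket import socket

with socket() as s:
    s.bind(("", 0))
    free_port = s.getsockname()[1]  # OS-assigned free port
print(free_port)
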
Example #9
    def do_update(self):
        self.provider.set_node_tags(
            self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
        cli_logger.labeled_value("New status", STATUS_WAITING_FOR_SSH)

        deadline = time.time() + NODE_START_WAIT_S
        self.wait_ready(deadline)

        node_tags = self.provider.node_tags(self.node_id)
        logger.debug("Node tags: {}".format(str(node_tags)))

        if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash:
            # When resuming from a stopped instance the runtime_hash may be the
            # same, but the container will not be started.
            self.cmd_runner.run_init(as_head=self.is_head_node,
                                     file_mounts=self.file_mounts)

        # runtime_hash will only change whenever the user restarts
        # or updates their cluster with `get_or_create_head_node`
        if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash and (
                not self.file_mounts_contents_hash
                or node_tags.get(TAG_RAY_FILE_MOUNTS_CONTENTS)
                == self.file_mounts_contents_hash):
            # todo: we lie in the confirmation message since
            # full setup might be cancelled here
            cli_logger.print(
                "Configuration already up to date, "
                "skipping file mounts, initalization and setup commands.",
                _numbered=("[]", "2-5", 6))
            cli_logger.old_info(logger,
                                "{}{} already up-to-date, skip to ray start",
                                self.log_prefix, self.node_id)

        else:
            cli_logger.print("Updating cluster configuration.",
                             _tags=dict(hash=self.runtime_hash))

            self.provider.set_node_tags(
                self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES})
            cli_logger.labeled_value("New status", STATUS_SYNCING_FILES)
            self.sync_file_mounts(self.rsync_up, step_numbers=(2, 6))

            # Only run setup commands if runtime_hash has changed because
            # we don't want to run setup_commands every time the head node
            # file_mounts folders have changed.
            if node_tags.get(TAG_RAY_RUNTIME_CONFIG) != self.runtime_hash:
                # Run init commands
                self.provider.set_node_tags(
                    self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP})
                cli_logger.labeled_value("New status", STATUS_SETTING_UP)

                if self.initialization_commands:
                    with cli_logger.group("Running initialization commands",
                                          _numbered=("[]", 3, 5)):
                        with LogTimer(self.log_prefix +
                                      "Initialization commands",
                                      show_status=True):
                            for cmd in self.initialization_commands:
                                try:
                                    # Overriding the existing SSHOptions class
                                    # with a new SSHOptions class that uses
                                    # this ssh_private_key as its only __init__
                                    # argument.
                                    # Run outside docker.
                                    self.cmd_runner.run(
                                        cmd,
                                        ssh_options_override_ssh_key=self.
                                        auth_config.get("ssh_private_key"),
                                        run_env="host")
                                except ProcessRunnerError as e:
                                    if e.msg_type == "ssh_command_failed":
                                        cli_logger.error("Failed.")
                                        cli_logger.error(
                                            "See above for stderr.")

                                    raise click.ClickException(
                                        "Initialization command failed."
                                    ) from None
                else:
                    cli_logger.print("No initialization commands to run.",
                                     _numbered=("[]", 3, 6))
                self.cmd_runner.run_init(as_head=self.is_head_node,
                                         file_mounts=self.file_mounts)
                if self.setup_commands:
                    with cli_logger.group(
                            "Running setup commands",
                            # todo: fix command numbering
                            _numbered=("[]", 4, 6)):
                        with LogTimer(self.log_prefix + "Setup commands",
                                      show_status=True):

                            total = len(self.setup_commands)
                            for i, cmd in enumerate(self.setup_commands):
                                if cli_logger.verbosity == 0 and len(cmd) > 30:
                                    cmd_to_print = cf.bold(cmd[:30]) + "..."
                                else:
                                    cmd_to_print = cf.bold(cmd)

                                cli_logger.print("{}",
                                                 cmd_to_print,
                                                 _numbered=("()", i, total))

                                try:
                                    # Runs in the container if docker is in use
                                    self.cmd_runner.run(cmd, run_env="auto")
                                except ProcessRunnerError as e:
                                    if e.msg_type == "ssh_command_failed":
                                        cli_logger.error("Failed.")
                                        cli_logger.error(
                                            "See above for stderr.")

                                    raise click.ClickException(
                                        "Setup command failed.")
                else:
                    cli_logger.print("No setup commands to run.",
                                     _numbered=("[]", 4, 6))

        with cli_logger.group("Starting the Ray runtime",
                              _numbered=("[]", 6, 6)):
            with LogTimer(self.log_prefix + "Ray start commands",
                          show_status=True):
                for cmd in self.ray_start_commands:
                    if self.node_resources:
                        env_vars = {
                            ray_constants.RESOURCES_ENVIRONMENT_VARIABLE:
                            self.node_resources
                        }
                    else:
                        env_vars = {}
                    try:
                        old_redirected = cmd_output_util.is_output_redirected()
                        cmd_output_util.set_output_redirected(False)
                        # Runs in the container if docker is in use
                        self.cmd_runner.run(cmd,
                                            environment_variables=env_vars,
                                            run_env="auto")
                        cmd_output_util.set_output_redirected(old_redirected)
                    except ProcessRunnerError as e:
                        if e.msg_type == "ssh_command_failed":
                            cli_logger.error("Failed.")
                            cli_logger.error("See above for stderr.")

                        raise click.ClickException("Start command failed.")
Example #10
def log_to_cli(config: Dict[str, Any]) -> None:
    provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(provider_name is not None,
                        "Could not find a pretty name for the AWS provider.")

    head_node_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_node_type][
        "node_config"]

    with cli_logger.group("{} config", provider_name):

        def print_info(resource_string: str,
                       key: str,
                       src_key: str,
                       allowed_tags: Optional[List[str]] = None,
                       list_value: bool = False) -> None:
            if allowed_tags is None:
                allowed_tags = ["default"]

            node_tags = {}

            # set of configurations corresponding to `key`
            unique_settings = set()

            for node_type_key, node_type in config[
                    "available_node_types"].items():
                node_tags[node_type_key] = {}
                tag = _log_info[src_key][node_type_key]
                if tag in allowed_tags:
                    node_tags[node_type_key][tag] = True
                setting = node_type["node_config"].get(key)

                if list_value:
                    unique_settings.add(tuple(setting))
                else:
                    unique_settings.add(setting)

            head_value_str = head_node_config[key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if len(unique_settings) == 1:
                # all node types are configured the same, condense
                # log output
                cli_logger.labeled_value(
                    resource_string + " (all available node types)",
                    "{}",
                    head_value_str,
                    _tags=node_tags[config["head_node_type"]])
            else:
                # do head node type first
                cli_logger.labeled_value(resource_string +
                                         f" ({head_node_type})",
                                         "{}",
                                         head_value_str,
                                         _tags=node_tags[head_node_type])

                # go through remaining types
                for node_type_key, node_type in config[
                        "available_node_types"].items():
                    if node_type_key == head_node_type:
                        continue
                    workers_value_str = node_type["node_config"][key]
                    if list_value:
                        workers_value_str = cli_logger.render_list(
                            workers_value_str)
                    cli_logger.labeled_value(resource_string +
                                             f" ({node_type_key})",
                                             "{}",
                                             workers_value_str,
                                             _tags=node_tags[node_type_key])

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        # head_node_config is the head_node_type's config,
        # config["head_node"] is a field that gets applied only to the actual
        # head node (and not workers of the head's node_type)
        assert ("IamInstanceProfile" in head_node_config
                or "IamInstanceProfile" in config["head_node"])
        if "IamInstanceProfile" in head_node_config:
            # If the user manually configured the role, we're here.
            IamProfile = head_node_config["IamInstanceProfile"]
        elif "IamInstanceProfile" in config["head_node"]:
            # If we filled the default IAM role, we're here.
            IamProfile = config["head_node"]["IamInstanceProfile"]
        profile_arn = IamProfile.get("Arn")
        profile_name = _arn_to_name(profile_arn) \
            if profile_arn \
            else IamProfile["Name"]
        cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags)

        if all("KeyName" in node_type["node_config"]
               for node_type in config["available_node_types"].values()):
            print_info("EC2 Key pair", "KeyName", "keypair_src")

        print_info("VPC Subnets", "SubnetIds", "subnet_src", list_value=True)
        print_info("EC2 Security groups",
                   "SecurityGroupIds",
                   "security_group_src",
                   list_value=True)
        print_info("EC2 AMI", "ImageId", "ami_src", allowed_tags=["dlami"])

    cli_logger.newline()
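
The condensation rule inside print_info above reduces to a set-cardinality check; a minimal sketch with hypothetical node configs.

node_configs = {
    "head": {"ImageId": "ami-123"},
    "worker": {"ImageId": "ami-123"},
}
unique_settings = {cfg.get("ImageId") for cfg in node_configs.values()}
# A single unique value -> print one "(all available node types)" line.
condense = len(unique_settings) == 1
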
Example #11
    def sync_file_mounts(self, sync_cmd, step_numbers=(0, 2)):
        # step_numbers is (# of previous steps, total steps)
        previous_steps, total_steps = step_numbers

        nolog_paths = []
        if cli_logger.verbosity == 0:
            nolog_paths = [
                "~/ray_bootstrap_key.pem", "~/ray_bootstrap_config.yaml"
            ]

        def do_sync(remote_path, local_path, allow_non_existing_paths=False):
            if allow_non_existing_paths and not os.path.exists(local_path):
                cli_logger.print("sync: {} does not exist. Skipping.",
                                 local_path)
                # Ignore missing source files. In the future we should support
                # the --delete-missing-args command to delete files that have
                # been removed
                return

            assert os.path.exists(local_path), local_path

            if os.path.isdir(local_path):
                if not local_path.endswith("/"):
                    local_path += "/"
                if not remote_path.endswith("/"):
                    remote_path += "/"

            with LogTimer(self.log_prefix +
                          "Synced {} to {}".format(local_path, remote_path)):
                is_docker = (self.docker_config
                             and self.docker_config["container_name"] != "")
                if not is_docker:
                    # The DockerCommandRunner handles this internally.
                    self.cmd_runner.run("mkdir -p {}".format(
                        os.path.dirname(remote_path)),
                                        run_env="host")
                sync_cmd(local_path, remote_path, file_mount=True)

                if remote_path not in nolog_paths:
                    # todo: timed here?
                    cli_logger.print("{} from {}", cf.bold(remote_path),
                                     cf.bold(local_path))

        # Rsync file mounts
        with cli_logger.group("Processing file mounts",
                              _numbered=("[]", previous_steps + 1,
                                         total_steps)):
            for remote_path, local_path in self.file_mounts.items():
                do_sync(remote_path, local_path)
            previous_steps += 1

        if self.cluster_synced_files:
            with cli_logger.group("Processing worker file mounts",
                                  _numbered=("[]", previous_steps + 1,
                                             total_steps)):
                cli_logger.print("synced files: {}",
                                 str(self.cluster_synced_files))
                for path in self.cluster_synced_files:
                    do_sync(path, path, allow_non_existing_paths=True)
                previous_steps += 1
        else:
            cli_logger.print("No worker file mounts to sync",
                             _numbered=("[]", previous_steps + 1, total_steps))
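
A hedged sketch of the trailing-slash normalization inside do_sync above, pulled out into a hypothetical helper: rsync copies the contents of "dir/" but the directory itself for "dir", so directory mounts are normalized on both ends.

import os


def normalize_for_rsync(local_path, remote_path):
    if os.path.isdir(local_path):
        if not local_path.endswith("/"):
            local_path += "/"
        if not remote_path.endswith("/"):
            remote_path += "/"
    return local_path, remote_path
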
Example #12
    def do_update(self):
        self.provider.set_node_tags(
            self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
        cli_logger.labeled_value("New status", STATUS_WAITING_FOR_SSH)

        deadline = time.time() + AUTOSCALER_NODE_START_WAIT_S
        self.wait_ready(deadline)
        global_event_system.execute_callback(
            CreateClusterEvent.ssh_control_acquired)

        node_tags = self.provider.node_tags(self.node_id)
        logger.debug("Node tags: {}".format(str(node_tags)))

        if self.provider_type == "aws" and self.provider.provider_config:
            from ray.autoscaler._private.aws.cloudwatch.cloudwatch_helper \
                import CloudwatchHelper
            CloudwatchHelper(self.provider.provider_config,
                             self.node_id, self.provider.cluster_name). \
                update_from_config(self.is_head_node)

        if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash:
            # When resuming from a stopped instance the runtime_hash may be the
            # same, but the container will not be started.
            init_required = self.cmd_runner.run_init(
                as_head=self.is_head_node,
                file_mounts=self.file_mounts,
                sync_run_yet=False)
            if init_required:
                node_tags[TAG_RAY_RUNTIME_CONFIG] += "-invalidate"
                # This ensures that `setup_commands` are not removed
                self.restart_only = False

        if self.restart_only:
            self.setup_commands = []

        # runtime_hash will only change whenever the user restarts
        # or updates their cluster with `get_or_create_head_node`
        if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash and (
                not self.file_mounts_contents_hash
                or node_tags.get(TAG_RAY_FILE_MOUNTS_CONTENTS)
                == self.file_mounts_contents_hash):
            # todo: we lie in the confirmation message since
            # full setup might be cancelled here
            cli_logger.print(
                "Configuration already up to date, "
                "skipping file mounts, initalization and setup commands.",
                _numbered=("[]", "2-6", NUM_SETUP_STEPS))

        else:
            cli_logger.print("Updating cluster configuration.",
                             _tags=dict(hash=self.runtime_hash))

            self.provider.set_node_tags(
                self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES})
            cli_logger.labeled_value("New status", STATUS_SYNCING_FILES)
            self.sync_file_mounts(self.rsync_up,
                                  step_numbers=(1, NUM_SETUP_STEPS))

            # Only run setup commands if runtime_hash has changed because
            # we don't want to run setup_commands every time the head node
            # file_mounts folders have changed.
            if node_tags.get(TAG_RAY_RUNTIME_CONFIG) != self.runtime_hash:
                # Run init commands
                self.provider.set_node_tags(
                    self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP})
                cli_logger.labeled_value("New status", STATUS_SETTING_UP)

                if self.initialization_commands:
                    with cli_logger.group("Running initialization commands",
                                          _numbered=("[]", 4,
                                                     NUM_SETUP_STEPS)):
                        global_event_system.execute_callback(
                            CreateClusterEvent.run_initialization_cmd)
                        with LogTimer(self.log_prefix +
                                      "Initialization commands",
                                      show_status=True):
                            for cmd in self.initialization_commands:
                                global_event_system.execute_callback(
                                    CreateClusterEvent.run_initialization_cmd,
                                    {"command": cmd})
                                try:
                                    # Overriding the existing SSHOptions class
                                    # with a new SSHOptions class that uses
                                    # this ssh_private_key as its only __init__
                                    # argument.
                                    # Run outside docker.
                                    self.cmd_runner.run(
                                        cmd,
                                        ssh_options_override_ssh_key=self.
                                        auth_config.get("ssh_private_key"),
                                        run_env="host")
                                except ProcessRunnerError as e:
                                    if e.msg_type == "ssh_command_failed":
                                        cli_logger.error("Failed.")
                                        cli_logger.error(
                                            "See above for stderr.")

                                    raise click.ClickException(
                                        "Initialization command failed."
                                    ) from None
                else:
                    cli_logger.print("No initialization commands to run.",
                                     _numbered=("[]", 4, NUM_SETUP_STEPS))
                with cli_logger.group(
                        "Initalizing command runner",
                        # todo: fix command numbering
                        _numbered=("[]", 5, NUM_SETUP_STEPS)):
                    self.cmd_runner.run_init(as_head=self.is_head_node,
                                             file_mounts=self.file_mounts,
                                             sync_run_yet=True)
                if self.setup_commands:
                    with cli_logger.group(
                            "Running setup commands",
                            # todo: fix command numbering
                            _numbered=("[]", 6, NUM_SETUP_STEPS)):
                        global_event_system.execute_callback(
                            CreateClusterEvent.run_setup_cmd)
                        with LogTimer(self.log_prefix + "Setup commands",
                                      show_status=True):

                            total = len(self.setup_commands)
                            for i, cmd in enumerate(self.setup_commands):
                                global_event_system.execute_callback(
                                    CreateClusterEvent.run_setup_cmd,
                                    {"command": cmd})
                                if cli_logger.verbosity == 0 and len(cmd) > 30:
                                    cmd_to_print = cf.bold(cmd[:30]) + "..."
                                else:
                                    cmd_to_print = cf.bold(cmd)

                                cli_logger.print("{}",
                                                 cmd_to_print,
                                                 _numbered=("()", i, total))

                                try:
                                    # Runs in the container if docker is in use
                                    self.cmd_runner.run(cmd, run_env="auto")
                                except ProcessRunnerError as e:
                                    if e.msg_type == "ssh_command_failed":
                                        cli_logger.error("Failed.")
                                        cli_logger.error(
                                            "See above for stderr.")

                                    raise click.ClickException(
                                        "Setup command failed.")
                else:
                    cli_logger.print("No setup commands to run.",
                                     _numbered=("[]", 6, NUM_SETUP_STEPS))

        with cli_logger.group("Starting the Ray runtime",
                              _numbered=("[]", 7, NUM_SETUP_STEPS)):
            global_event_system.execute_callback(
                CreateClusterEvent.start_ray_runtime)
            with LogTimer(self.log_prefix + "Ray start commands",
                          show_status=True):
                for cmd in self.ray_start_commands:

                    # Add a resource override env variable if needed:
                    if self.provider_type == "local":
                        # Local NodeProvider doesn't need resource override.
                        env_vars = {}
                    elif self.node_resources:
                        env_vars = {
                            RESOURCES_ENVIRONMENT_VARIABLE: self.node_resources
                        }
                    else:
                        env_vars = {}

                    try:
                        old_redirected = cmd_output_util.is_output_redirected()
                        cmd_output_util.set_output_redirected(False)
                        # Runs in the container if docker is in use
                        self.cmd_runner.run(cmd,
                                            environment_variables=env_vars,
                                            run_env="auto")
                        cmd_output_util.set_output_redirected(old_redirected)
                    except ProcessRunnerError as e:
                        if e.msg_type == "ssh_command_failed":
                            cli_logger.error("Failed.")
                            cli_logger.error("See above for stderr.")

                        raise click.ClickException("Start command failed.")
            global_event_system.execute_callback(
                CreateClusterEvent.start_ray_runtime_completed)
Example #13
def get_or_create_head_node(config: Dict[str, Any],
                            printable_config_file: str,
                            no_restart: bool,
                            restart_only: bool,
                            yes: bool,
                            override_cluster_name: Optional[str],
                            no_monitor_on_head: bool = False,
                            _provider: Optional[NodeProvider] = None,
                            _runner: ModuleType = subprocess) -> None:
    """Create the cluster head node, which in turn creates the workers."""
    global_event_system.execute_callback(
        CreateClusterEvent.cluster_booting_started)
    provider = (_provider or _get_node_provider(config["provider"],
                                                config["cluster_name"]))

    config = copy.deepcopy(config)
    head_node_tags = {
        TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
    }
    nodes = provider.non_terminated_nodes(head_node_tags)
    if len(nodes) > 0:
        head_node = nodes[0]
    else:
        head_node = None

    if not head_node:
        cli_logger.confirm(yes, "No head node found. "
                           "Launching a new cluster.",
                           _abort=True)

    if head_node:
        if restart_only:
            cli_logger.confirm(yes, "Updating cluster configuration and "
                               "restarting the cluster Ray runtime. "
                               "Setup commands will not be run due to `{}`.\n",
                               cf.bold("--restart-only"),
                               _abort=True)
        elif no_restart:
            cli_logger.print(
                "Cluster Ray runtime will not be restarted due "
                "to `{}`.", cf.bold("--no-restart"))
            cli_logger.confirm(yes, "Updating cluster configuration and "
                               "running setup commands.",
                               _abort=True)
        else:
            cli_logger.print(
                "Updating cluster configuration and running full setup.")
            cli_logger.confirm(
                yes,
                cf.bold("Cluster Ray runtime will be restarted."),
                _abort=True)

    cli_logger.newline()
    # TODO(ekl) this logic is duplicated in node_launcher.py (keep in sync)
    head_node_config = copy.deepcopy(config["head_node"])
    head_node_resources = None
    if "head_node_type" in config:
        head_node_type = config["head_node_type"]
        head_node_tags[TAG_RAY_USER_NODE_TYPE] = head_node_type
        head_config = config["available_node_types"][head_node_type]
        head_node_config.update(head_config["node_config"])

        # Not necessary to keep in sync with node_launcher.py
        # Keep in sync with autoscaler.py _node_resources
        head_node_resources = head_config.get("resources")

    launch_hash = hash_launch_conf(head_node_config, config["auth"])
    if head_node is None or provider.node_tags(head_node).get(
            TAG_RAY_LAUNCH_CONFIG) != launch_hash:
        with cli_logger.group("Acquiring an up-to-date head node"):
            global_event_system.execute_callback(
                CreateClusterEvent.acquiring_new_head_node)
            if head_node is not None:
                cli_logger.print(
                    "Currently running head node is out-of-date with "
                    "cluster configuration")
                cli_logger.print(
                    "hash is {}, expected {}",
                    cf.bold(
                        provider.node_tags(head_node).get(
                            TAG_RAY_LAUNCH_CONFIG)), cf.bold(launch_hash))
                cli_logger.confirm(yes, "Relaunching it.", _abort=True)

                provider.terminate_node(head_node)
                cli_logger.print("Terminated head node {}", head_node)

            head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
            head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
                config["cluster_name"])
            provider.create_node(head_node_config, head_node_tags, 1)
            cli_logger.print("Launched a new head node")

            start = time.time()
            head_node = None
            with cli_logger.group("Fetching the new head node"):
                while True:
                    if time.time() - start > 50:
                        cli_logger.abort(
                            "Head node fetch timed out.")  # todo: msg
                        raise RuntimeError("Failed to create head node.")
                    nodes = provider.non_terminated_nodes(head_node_tags)
                    if len(nodes) == 1:
                        head_node = nodes[0]
                        break
                    time.sleep(POLL_INTERVAL)
            cli_logger.newline()

    global_event_system.execute_callback(CreateClusterEvent.head_node_acquired)

    with cli_logger.group(
            "Setting up head node",
            _numbered=("<>", 1, 1),
            # cf.bold(provider.node_tags(head_node)[TAG_RAY_NODE_NAME]),
            _tags=dict()):  # add id, ARN to tags?

        # TODO(ekl) right now we always update the head node even if the
        # hash matches.
        # We could prompt the user for what they want to do here.
        # No need to pass in cluster_sync_files because we use this
        # hash to set up the head node
        (runtime_hash, file_mounts_contents_hash) = hash_runtime_conf(
            config["file_mounts"], None, config)

        if not no_monitor_on_head:
            # Return remote_config_file to avoid prematurely closing it.
            config, remote_config_file = _set_up_config_for_head_node(
                config, provider, no_restart)
            cli_logger.print("Prepared bootstrap config")

        if restart_only:
            setup_commands = []
            ray_start_commands = config["head_start_ray_commands"]
        elif no_restart:
            setup_commands = config["head_setup_commands"]
            ray_start_commands = []
        else:
            setup_commands = config["head_setup_commands"]
            ray_start_commands = config["head_start_ray_commands"]

        if not no_restart:
            warn_about_bad_start_command(ray_start_commands,
                                         no_monitor_on_head)

        updater = NodeUpdaterThread(
            node_id=head_node,
            provider_config=config["provider"],
            provider=provider,
            auth_config=config["auth"],
            cluster_name=config["cluster_name"],
            file_mounts=config["file_mounts"],
            initialization_commands=config["initialization_commands"],
            setup_commands=setup_commands,
            ray_start_commands=ray_start_commands,
            process_runner=_runner,
            runtime_hash=runtime_hash,
            file_mounts_contents_hash=file_mounts_contents_hash,
            is_head_node=True,
            node_resources=head_node_resources,
            rsync_options={
                "rsync_exclude": config.get("rsync_exclude"),
                "rsync_filter": config.get("rsync_filter")
            },
            docker_config=config.get("docker"))
        updater.start()
        updater.join()

        # Refresh the node cache so we see the external ip if available
        provider.non_terminated_nodes(head_node_tags)

        if updater.exitcode != 0:
            # todo: this does not follow the mockup and is not good enough
            cli_logger.abort("Failed to setup head node.")
            sys.exit(1)

    global_event_system.execute_callback(
        CreateClusterEvent.cluster_booting_completed, {
            "head_node_id": head_node,
        })

    monitor_str = "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
    if override_cluster_name:
        modifiers = " --cluster-name={}".format(quote(override_cluster_name))
    else:
        modifiers = ""

    cli_logger.newline()
    with cli_logger.group("Useful commands"):
        printable_config_file = os.path.abspath(printable_config_file)
        cli_logger.print("Monitor autoscaling with")
        cli_logger.print(cf.bold("  ray exec {}{} {}"), printable_config_file,
                         modifiers, quote(monitor_str))

        cli_logger.print("Connect to a terminal on the cluster head:")
        cli_logger.print(cf.bold("  ray attach {}{}"), printable_config_file,
                         modifiers)

        remote_shell_str = updater.cmd_runner.remote_shell_command_str()
        cli_logger.print("Get a remote shell to the cluster manually:")
        cli_logger.print("  {}", remote_shell_str.strip())
Example #14
    def _create_node(self, node_config, tags, count):
        created_nodes_dict = {}

        tags = to_aws_format(tags)
        conf = node_config.copy()

        tag_pairs = [{
            "Key": TAG_RAY_CLUSTER_NAME,
            "Value": self.cluster_name,
        }]
        for k, v in tags.items():
            tag_pairs.append({
                "Key": k,
                "Value": v,
            })
        if CloudwatchHelper.cloudwatch_config_exists(
                self.provider_config, CloudwatchConfigType.AGENT):
            cwa_installed = self._check_ami_cwa_installation(node_config)
            if cwa_installed:
                tag_pairs.extend([{
                    "Key": CLOUDWATCH_AGENT_INSTALLED_TAG,
                    "Value": "True",
                }])
        tag_specs = [{
            "ResourceType": "instance",
            "Tags": tag_pairs,
        }]
        user_tag_specs = conf.get("TagSpecifications", [])
        AWSNodeProvider._merge_tag_specs(tag_specs, user_tag_specs)

        # SubnetIds is not a real config key: we must resolve to a
        # single SubnetId before invoking the AWS API.
        subnet_ids = conf.pop("SubnetIds")

        # update config with min/max node counts and tag specs
        conf.update({
            "MinCount": 1,
            "MaxCount": count,
            "TagSpecifications": tag_specs
        })

        # Try to always launch in the first listed subnet.
        subnet_idx = 0
        cli_logger_tags = {}
        # NOTE: This ensures that we try ALL availability zones before
        # throwing an error.
        max_tries = max(BOTO_CREATE_MAX_RETRIES, len(subnet_ids))
        for attempt in range(1, max_tries + 1):
            try:
                if "NetworkInterfaces" in conf:
                    net_ifs = conf["NetworkInterfaces"]
                    # remove security group IDs previously copied from network
                    # interfaces (create_instances call fails otherwise)
                    conf.pop("SecurityGroupIds", None)
                    cli_logger_tags["network_interfaces"] = str(net_ifs)
                else:
                    subnet_id = subnet_ids[subnet_idx % len(subnet_ids)]
                    conf["SubnetId"] = subnet_id
                    cli_logger_tags["subnet_id"] = subnet_id

                created = self.ec2_fail_fast.create_instances(**conf)
                created_nodes_dict = {n.id: n for n in created}

                # todo: timed?
                # todo: handle plurality?
                with cli_logger.group("Launched {} nodes",
                                      count,
                                      _tags=cli_logger_tags):
                    for instance in created:
                        # NOTE(maximsmol): This is needed for mocking
                        # boto3 for tests. This is likely a bug in moto
                        # but AWS docs don't seem to say.
                        # You can patch moto/ec2/responses/instances.py
                        # to fix this (add <stateReason> to EC2_RUN_INSTANCES)

                        # The correct value is technically
                        # {"code": "0", "Message": "pending"}
                        state_reason = instance.state_reason or {
                            "Message": "pending"
                        }

                        cli_logger.print("Launched instance {}",
                                         instance.instance_id,
                                         _tags=dict(
                                             state=instance.state["Name"],
                                             info=state_reason["Message"]))
                break
            except botocore.exceptions.ClientError as exc:
                if attempt == max_tries:
                    cli_logger.abort(
                        "Failed to launch instances. Max attempts exceeded.",
                        exc=exc,
                    )
                else:
                    cli_logger.warning(
                        "create_instances: Attempt failed with {}, retrying.",
                        exc)

                # Launch failure may be due to instance type availability in
                # the given AZ
                subnet_idx += 1

        return created_nodes_dict
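The retry loop in _create_node rotates through the configured subnets so that a capacity or availability error in one zone does not abort the whole launch; only after max(BOTO_CREATE_MAX_RETRIES, len(subnet_ids)) failed attempts does it give up. A stripped-down sketch of that fallback pattern (hypothetical helper, not Ray source):

from typing import Any, Callable, Dict, List

def launch_with_subnet_fallback(subnet_ids: List[str],
                                launch_fn: Callable[[str], Dict[str, Any]],
                                min_attempts: int = 5) -> Dict[str, Any]:
    """Try each subnet in round-robin order until a launch succeeds (sketch)."""
    # Assumes at least one subnet is configured.
    max_tries = max(min_attempts, len(subnet_ids))
    last_error = None
    for attempt in range(max_tries):
        subnet_id = subnet_ids[attempt % len(subnet_ids)]
        try:
            # launch_fn stands in for ec2.create_instances(SubnetId=subnet_id, ...)
            return launch_fn(subnet_id)
        except Exception as exc:  # e.g. botocore.exceptions.ClientError
            last_error = exc
    raise RuntimeError(f"All {max_tries} launch attempts failed") from last_error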
Example #15
def get_or_create_head_node(config: Dict[str, Any],
                            config_file: str,
                            no_restart: bool,
                            restart_only: bool,
                            yes: bool,
                            override_cluster_name: Optional[str],
                            _provider: Optional[NodeProvider] = None,
                            _runner: ModuleType = subprocess) -> None:
    """Create the cluster head node, which in turn creates the workers."""
    provider = (_provider or _get_node_provider(config["provider"],
                                                config["cluster_name"]))

    config = copy.deepcopy(config)
    config_file = os.path.abspath(config_file)
    try:
        head_node_tags = {
            TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
        }
        nodes = provider.non_terminated_nodes(head_node_tags)
        if len(nodes) > 0:
            head_node = nodes[0]
        else:
            head_node = None

        if not head_node:
            cli_logger.confirm(yes, "No head node found. "
                               "Launching a new cluster.",
                               _abort=True)
            cli_logger.old_confirm("This will create a new cluster", yes)
        elif not no_restart:
            cli_logger.old_confirm("This will restart cluster services", yes)

        if head_node:
            if restart_only:
                cli_logger.confirm(
                    yes, "Updating cluster configuration and "
                    "restarting the cluster Ray runtime. "
                    "Setup commands will not be run due to `{}`.\n",
                    cf.bold("--restart-only"),
                    _abort=True)
            elif no_restart:
                cli_logger.print(
                    "Cluster Ray runtime will not be restarted due "
                    "to `{}`.", cf.bold("--no-restart"))
                cli_logger.confirm(yes, "Updating cluster configuration and "
                                   "running setup commands.",
                                   _abort=True)
            else:
                cli_logger.print(
                    "Updating cluster configuration and running full setup.")
                cli_logger.confirm(
                    yes,
                    cf.bold("Cluster Ray runtime will be restarted."),
                    _abort=True)
        cli_logger.newline()

        # TODO(ekl) this logic is duplicated in node_launcher.py (keep in sync)
        head_node_config = copy.deepcopy(config["head_node"])
        if "head_node_type" in config:
            head_node_tags[TAG_RAY_USER_NODE_TYPE] = config["head_node_type"]
            head_node_config.update(config["available_node_types"][
                config["head_node_type"]]["node_config"])

        launch_hash = hash_launch_conf(head_node_config, config["auth"])
        if head_node is None or provider.node_tags(head_node).get(
                TAG_RAY_LAUNCH_CONFIG) != launch_hash:
            with cli_logger.group("Acquiring an up-to-date head node"):
                if head_node is not None:
                    cli_logger.print(
                        "Currently running head node is out-of-date with "
                        "cluster configuration")
                    cli_logger.print(
                        "hash is {}, expected {}",
                        cf.bold(
                            provider.node_tags(head_node).get(
                                TAG_RAY_LAUNCH_CONFIG)), cf.bold(launch_hash))
                    cli_logger.confirm(yes, "Relaunching it.", _abort=True)
                    cli_logger.old_confirm(
                        "Head node config out-of-date. It will be terminated",
                        yes)

                    cli_logger.old_info(
                        logger, "get_or_create_head_node: "
                        "Shutting down outdated head node {}", head_node)

                    provider.terminate_node(head_node)
                    cli_logger.print("Terminated head node {}", head_node)

                cli_logger.old_info(
                    logger,
                    "get_or_create_head_node: Launching new head node...")

                head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
                head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
                    config["cluster_name"])
                provider.create_node(head_node_config, head_node_tags, 1)
                cli_logger.print("Launched a new head node")

                start = time.time()
                head_node = None
                with cli_logger.group("Fetching the new head node"):
                    while True:
                        if time.time() - start > 50:
                            cli_logger.abort(
                                "Head node fetch timed out.")  # todo: msg
                            raise RuntimeError("Failed to create head node.")
                        nodes = provider.non_terminated_nodes(head_node_tags)
                        if len(nodes) == 1:
                            head_node = nodes[0]
                            break
                        time.sleep(POLL_INTERVAL)
                cli_logger.newline()

        with cli_logger.group(
                "Setting up head node",
                _numbered=("<>", 1, 1),
                # cf.bold(provider.node_tags(head_node)[TAG_RAY_NODE_NAME]),
                _tags=dict()):  # add id, ARN to tags?

            # TODO(ekl) right now we always update the head node even if the
            # hash matches.
            # We could prompt the user for what they want to do here.
            # No need to pass in cluster_sync_files because we use this
            # hash to set up the head node
            (runtime_hash, file_mounts_contents_hash) = hash_runtime_conf(
                config["file_mounts"], None, config)

            cli_logger.old_info(
                logger,
                "get_or_create_head_node: Updating files on head node...")

            # Rewrite the auth config so that the head
            # node can update the workers
            remote_config = copy.deepcopy(config)

            # drop proxy options if they exist, otherwise
            # head node won't be able to connect to workers
            remote_config["auth"].pop("ssh_proxy_command", None)

            if "ssh_private_key" in config["auth"]:
                remote_key_path = "~/ray_bootstrap_key.pem"
                remote_config["auth"]["ssh_private_key"] = remote_key_path

            # Adjust for new file locations
            new_mounts = {}
            for remote_path in config["file_mounts"]:
                new_mounts[remote_path] = remote_path
            remote_config["file_mounts"] = new_mounts
            remote_config["no_restart"] = no_restart

            remote_config = provider.prepare_for_head_node(remote_config)

            # Now inject the rewritten config and SSH key into the head node
            remote_config_file = tempfile.NamedTemporaryFile(
                "w", prefix="ray-bootstrap-")
            remote_config_file.write(json.dumps(remote_config))
            remote_config_file.flush()
            config["file_mounts"].update(
                {"~/ray_bootstrap_config.yaml": remote_config_file.name})

            if "ssh_private_key" in config["auth"]:
                config["file_mounts"].update({
                    remote_key_path:
                    config["auth"]["ssh_private_key"],
                })
            cli_logger.print("Prepared bootstrap config")

            if restart_only:
                setup_commands = []
                ray_start_commands = config["head_start_ray_commands"]
            elif no_restart:
                setup_commands = config["head_setup_commands"]
                ray_start_commands = []
            else:
                setup_commands = config["head_setup_commands"]
                ray_start_commands = config["head_start_ray_commands"]

            if not no_restart:
                warn_about_bad_start_command(ray_start_commands)

            updater = NodeUpdaterThread(
                node_id=head_node,
                provider_config=config["provider"],
                provider=provider,
                auth_config=config["auth"],
                cluster_name=config["cluster_name"],
                file_mounts=config["file_mounts"],
                initialization_commands=config["initialization_commands"],
                setup_commands=setup_commands,
                ray_start_commands=ray_start_commands,
                process_runner=_runner,
                runtime_hash=runtime_hash,
                file_mounts_contents_hash=file_mounts_contents_hash,
                is_head_node=True,
                docker_config=config.get("docker"))
            updater.start()
            updater.join()

            # Refresh the node cache so we see the external ip if available
            provider.non_terminated_nodes(head_node_tags)

            if config.get("provider", {}).get("use_internal_ips",
                                              False) is True:
                head_node_ip = provider.internal_ip(head_node)
            else:
                head_node_ip = provider.external_ip(head_node)

            if updater.exitcode != 0:
                # todo: this does not follow the mockup and is not good enough
                cli_logger.abort("Failed to setup head node.")

                cli_logger.old_error(
                    logger, "get_or_create_head_node: "
                    "Updating {} failed", head_node_ip)
                sys.exit(1)

            cli_logger.old_info(
                logger, "get_or_create_head_node: "
                "Head node up-to-date, IP address is: {}", head_node_ip)

        monitor_str = "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
        if override_cluster_name:
            modifiers = " --cluster-name={}".format(
                quote(override_cluster_name))
        else:
            modifiers = ""

        if cli_logger.old_style:
            print("To monitor autoscaling activity, you can run:\n\n"
                  "  ray exec {} {}{}\n".format(config_file,
                                                quote(monitor_str), modifiers))
            print("To open a console on the cluster:\n\n"
                  "  ray attach {}{}\n".format(config_file, modifiers))

            print("To get a remote shell to the cluster manually, run:\n\n"
                  "  {}\n".format(
                      updater.cmd_runner.remote_shell_command_str()))

        cli_logger.newline()
        with cli_logger.group("Useful commands"):
            cli_logger.print("Monitor autoscaling with")
            cli_logger.print(cf.bold("  ray exec {}{} {}"), config_file,
                             modifiers, quote(monitor_str))

            cli_logger.print("Connect to a terminal on the cluster head:")
            cli_logger.print(cf.bold("  ray attach {}{}"), config_file,
                             modifiers)

            remote_shell_str = updater.cmd_runner.remote_shell_command_str()
            cli_logger.print("Get a remote shell to the cluster manually:")
            cli_logger.print("  {}", remote_shell_str.strip())
    finally:
        provider.cleanup()
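Example #15 spells out the bootstrap-config rewrite that Example #13 delegates to _set_up_config_for_head_node: before the config is copied to the head node, it loses its SSH proxy option, points the private key at the path used on the head, and maps every file mount onto itself, since those files already live at their remote paths once the head node is set up. A condensed sketch of that rewrite (assumptions noted in comments, not Ray source):

import copy
from typing import Any, Dict

def sketch_prepare_remote_config(config: Dict[str, Any],
                                 no_restart: bool) -> Dict[str, Any]:
    """Rewrite a cluster config so the head node can manage the workers (sketch)."""
    remote = copy.deepcopy(config)
    # The head node talks to workers directly, so a local proxy command would break it.
    remote["auth"].pop("ssh_proxy_command", None)
    if "ssh_private_key" in config["auth"]:
        # Path matches the key location injected into file_mounts above.
        remote["auth"]["ssh_private_key"] = "~/ray_bootstrap_key.pem"
    # On the head node the mounted files are already at their remote paths.
    remote["file_mounts"] = {path: path for path in config["file_mounts"]}
    remote["no_restart"] = no_restart
    return remote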