Example #1
def create_or_update_cluster(
        config_file: str,
        override_min_workers: Optional[int],
        override_max_workers: Optional[int],
        no_restart: bool,
        restart_only: bool,
        yes: bool,
        override_cluster_name: Optional[str] = None,
        no_config_cache: bool = False,
        redirect_command_output: Optional[bool] = False,
        use_login_shells: bool = True,
        no_monitor_on_head: bool = False) -> Dict[str, Any]:
    """Creates or updates an autoscaling Ray cluster from a config json."""
    # no_monitor_on_head is an internal flag used by the Ray K8s operator.
    # If True, prevents autoscaling config sync to the Ray head during cluster
    # creation. See https://github.com/ray-project/ray/pull/13720.
    set_using_login_shells(use_login_shells)
    if not use_login_shells:
        cmd_output_util.set_allow_interactive(False)
    if redirect_command_output is None:
        # Do not redirect by default.
        cmd_output_util.set_output_redirected(False)
    else:
        cmd_output_util.set_output_redirected(redirect_command_output)

    def handle_yaml_error(e):
        cli_logger.error("Cluster config invalid")
        cli_logger.newline()
        cli_logger.error("Failed to load YAML file " + cf.bold("{}"),
                         config_file)
        cli_logger.newline()
        with cli_logger.verbatim_error_ctx("PyYAML error:"):
            cli_logger.error(e)
        cli_logger.abort()

    try:
        config = yaml.safe_load(open(config_file).read())
    except FileNotFoundError:
        cli_logger.abort(
            "Provided cluster configuration file ({}) does not exist",
            cf.bold(config_file))
        raise
    except yaml.parser.ParserError as e:
        handle_yaml_error(e)
        raise
    except yaml.scanner.ScannerError as e:
        handle_yaml_error(e)
        raise
    global_event_system.execute_callback(CreateClusterEvent.up_started,
                                         {"cluster_config": config})

    # todo: validate file_mounts, ssh keys, etc.

    importer = _NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        cli_logger.abort(
            "Unknown provider type " + cf.bold("{}") + "\n"
            "Available providers are: {}", config["provider"]["type"],
            cli_logger.render_list([
                k for k in _NODE_PROVIDERS.keys()
                if _NODE_PROVIDERS[k] is not None
            ]))
        raise NotImplementedError("Unsupported provider {}".format(
            config["provider"]))

    printed_overrides = False

    def handle_cli_override(key, override):
        if override is not None:
            if key in config:
                nonlocal printed_overrides
                printed_overrides = True
                cli_logger.warning(
                    "`{}` override provided on the command line.\n"
                    "  Using " + cf.bold("{}") +
                    cf.dimmed(" [configuration file has " + cf.bold("{}") +
                              "]"), key, override, config[key])
            config[key] = override

    handle_cli_override("min_workers", override_min_workers)
    handle_cli_override("max_workers", override_max_workers)
    handle_cli_override("cluster_name", override_cluster_name)

    if printed_overrides:
        cli_logger.newline()

    cli_logger.labeled_value("Cluster", config["cluster_name"])

    cli_logger.newline()
    config = _bootstrap_config(config, no_config_cache=no_config_cache)

    try_logging_config(config)
    get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
                            override_cluster_name, no_monitor_on_head)
    return config
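A minimal usage sketch for the function above; the config path and all flag values are hypothetical, not taken from the original example:

if __name__ == "__main__":
    updated_config = create_or_update_cluster(
        config_file="cluster.yaml",    # hypothetical cluster config path
        override_min_workers=None,     # keep min_workers from the YAML
        override_max_workers=None,     # keep max_workers from the YAML
        no_restart=False,
        restart_only=False,
        yes=True,                      # skip interactive confirmation prompts
    )
    print(updated_config["cluster_name"])
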
Example #2
def get_or_create_head_node(config: Dict[str, Any],
                            printable_config_file: str,
                            no_restart: bool,
                            restart_only: bool,
                            yes: bool,
                            override_cluster_name: Optional[str],
                            no_monitor_on_head: bool = False,
                            _provider: Optional[NodeProvider] = None,
                            _runner: ModuleType = subprocess) -> None:
    """Create the cluster head node, which in turn creates the workers."""
    global_event_system.execute_callback(
        CreateClusterEvent.cluster_booting_started)
    provider = (_provider or _get_node_provider(config["provider"],
                                                config["cluster_name"]))

    config = copy.deepcopy(config)
    head_node_tags = {
        TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
    }
    nodes = provider.non_terminated_nodes(head_node_tags)
    if len(nodes) > 0:
        head_node = nodes[0]
    else:
        head_node = None

    if not head_node:
        cli_logger.confirm(yes, "No head node found. "
                           "Launching a new cluster.",
                           _abort=True)

    if head_node:
        if restart_only:
            cli_logger.confirm(yes, "Updating cluster configuration and "
                               "restarting the cluster Ray runtime. "
                               "Setup commands will not be run due to `{}`.\n",
                               cf.bold("--restart-only"),
                               _abort=True)
        elif no_restart:
            cli_logger.print(
                "Cluster Ray runtime will not be restarted due "
                "to `{}`.", cf.bold("--no-restart"))
            cli_logger.confirm(yes, "Updating cluster configuration and "
                               "running setup commands.",
                               _abort=True)
        else:
            cli_logger.print(
                "Updating cluster configuration and running full setup.")
            cli_logger.confirm(
                yes,
                cf.bold("Cluster Ray runtime will be restarted."),
                _abort=True)

    cli_logger.newline()
    # TODO(ekl) this logic is duplicated in node_launcher.py (keep in sync)
    head_node_config = copy.deepcopy(config["head_node"])
    head_node_resources = None
    if "head_node_type" in config:
        head_node_type = config["head_node_type"]
        head_node_tags[TAG_RAY_USER_NODE_TYPE] = head_node_type
        head_config = config["available_node_types"][head_node_type]
        head_node_config.update(head_config["node_config"])

        # Not necessary to keep in sync with node_launcher.py
        # Keep in sync with autoscaler.py _node_resources
        head_node_resources = head_config.get("resources")

    launch_hash = hash_launch_conf(head_node_config, config["auth"])
    if head_node is None or provider.node_tags(head_node).get(
            TAG_RAY_LAUNCH_CONFIG) != launch_hash:
        with cli_logger.group("Acquiring an up-to-date head node"):
            global_event_system.execute_callback(
                CreateClusterEvent.acquiring_new_head_node)
            if head_node is not None:
                cli_logger.print(
                    "Currently running head node is out-of-date with "
                    "cluster configuration")
                cli_logger.print(
                    "hash is {}, expected {}",
                    cf.bold(
                        provider.node_tags(head_node).get(
                            TAG_RAY_LAUNCH_CONFIG)), cf.bold(launch_hash))
                cli_logger.confirm(yes, "Relaunching it.", _abort=True)

                provider.terminate_node(head_node)
                cli_logger.print("Terminated head node {}", head_node)

            head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
            head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
                config["cluster_name"])
            provider.create_node(head_node_config, head_node_tags, 1)
            cli_logger.print("Launched a new head node")

            start = time.time()
            head_node = None
            with cli_logger.group("Fetching the new head node"):
                while True:
                    if time.time() - start > 50:
                        cli_logger.abort(
                            "Head node fetch timed out.")  # todo: msg
                        raise RuntimeError("Failed to create head node.")
                    nodes = provider.non_terminated_nodes(head_node_tags)
                    if len(nodes) == 1:
                        head_node = nodes[0]
                        break
                    time.sleep(POLL_INTERVAL)
            cli_logger.newline()

    global_event_system.execute_callback(CreateClusterEvent.head_node_acquired)

    with cli_logger.group(
            "Setting up head node",
            _numbered=("<>", 1, 1),
            # cf.bold(provider.node_tags(head_node)[TAG_RAY_NODE_NAME]),
            _tags=dict()):  # add id, ARN to tags?

        # TODO(ekl) right now we always update the head node even if the
        # hash matches.
        # We could prompt the user for what they want to do here.
        # No need to pass in cluster_sync_files because we use this
        # hash to set up the head node
        (runtime_hash, file_mounts_contents_hash) = hash_runtime_conf(
            config["file_mounts"], None, config)

        if not no_monitor_on_head:
            # Return remote_config_file to avoid prematurely closing it.
            config, remote_config_file = _set_up_config_for_head_node(
                config, provider, no_restart)
            cli_logger.print("Prepared bootstrap config")

        if restart_only:
            setup_commands = []
            ray_start_commands = config["head_start_ray_commands"]
        elif no_restart:
            setup_commands = config["head_setup_commands"]
            ray_start_commands = []
        else:
            setup_commands = config["head_setup_commands"]
            ray_start_commands = config["head_start_ray_commands"]

        if not no_restart:
            warn_about_bad_start_command(ray_start_commands,
                                         no_monitor_on_head)

        updater = NodeUpdaterThread(
            node_id=head_node,
            provider_config=config["provider"],
            provider=provider,
            auth_config=config["auth"],
            cluster_name=config["cluster_name"],
            file_mounts=config["file_mounts"],
            initialization_commands=config["initialization_commands"],
            setup_commands=setup_commands,
            ray_start_commands=ray_start_commands,
            process_runner=_runner,
            runtime_hash=runtime_hash,
            file_mounts_contents_hash=file_mounts_contents_hash,
            is_head_node=True,
            node_resources=head_node_resources,
            rsync_options={
                "rsync_exclude": config.get("rsync_exclude"),
                "rsync_filter": config.get("rsync_filter")
            },
            docker_config=config.get("docker"))
        updater.start()
        updater.join()

        # Refresh the node cache so we see the external ip if available
        provider.non_terminated_nodes(head_node_tags)

        if updater.exitcode != 0:
            # todo: this does not follow the mockup and is not good enough
            cli_logger.abort("Failed to setup head node.")
            sys.exit(1)

    global_event_system.execute_callback(
        CreateClusterEvent.cluster_booting_completed, {
            "head_node_id": head_node,
        })

    monitor_str = "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
    if override_cluster_name:
        modifiers = " --cluster-name={}".format(quote(override_cluster_name))
    else:
        modifiers = ""

    cli_logger.newline()
    with cli_logger.group("Useful commands"):
        printable_config_file = os.path.abspath(printable_config_file)
        cli_logger.print("Monitor autoscaling with")
        cli_logger.print(cf.bold("  ray exec {}{} {}"), printable_config_file,
                         modifiers, quote(monitor_str))

        cli_logger.print("Connect to a terminal on the cluster head:")
        cli_logger.print(cf.bold("  ray attach {}{}"), printable_config_file,
                         modifiers)

        remote_shell_str = updater.cmd_runner.remote_shell_command_str()
        cli_logger.print("Get a remote shell to the cluster manually:")
        cli_logger.print("  {}", remote_shell_str.strip())
Example #3
def _configure_key_pair(config):
    if "ssh_private_key" in config["auth"]:
        _set_config_info(keypair_src="config")

        # If the key is not configured via the cloudinit
        # UserData, it should be configured via KeyName or
        # else we will risk starting a node that we cannot
        # SSH into:

        if "UserData" not in config["head_node"]:
            cli_logger.doassert(  # todo: verify schema beforehand?
                "KeyName" in config["head_node"],
                "`KeyName` missing for head node.")  # todo: err msg
            assert "KeyName" in config["head_node"]

        if "UserData" not in config["worker_nodes"]:
            cli_logger.doassert(
                "KeyName" in config["worker_nodes"],
                "`KeyName` missing for worker nodes.")  # todo: err msg
            assert "KeyName" in config["worker_nodes"]

        return config

    _set_config_info(keypair_src="default")

    ec2 = _resource("ec2", config)

    # Writing the new ssh key to the filesystem fails if the ~/.ssh
    # directory doesn't already exist.
    os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)

    # Try a few times to get or create a good key pair.
    MAX_NUM_KEYS = 30
    for i in range(MAX_NUM_KEYS):

        key_name = config["provider"].get("key_pair", {}).get("key_name")

        key_name, key_path = key_pair(i, config["provider"]["region"],
                                      key_name)
        key = _get_key(key_name, config)

        # Found a good key.
        if key and os.path.exists(key_path):
            break

        # We can safely create a new key.
        if not key and not os.path.exists(key_path):
            cli_logger.verbose(
                "Creating new key pair {} for use as the default.",
                cf.bold(key_name))
            key = ec2.create_key_pair(KeyName=key_name)

            # We need to make sure to _create_ the file with the right
            # permissions. In order to do that we need to change the default
            # os.open behavior to include the mode we want.
            with open(key_path, "w", opener=partial(os.open, mode=0o600)) as f:
                f.write(key.key_material)
            break

    if not key:
        cli_logger.abort(
            "No matching local key file for any of the key pairs in this "
            "account with ids from 0..{}. "
            "Consider deleting some unused keys pairs from your account.",
            key_name)  # todo: err msg
        raise ValueError(
            "No matching local key file for any of the key pairs in this "
            "account with ids from 0..{}. ".format(key_name) +
            "Consider deleting some unused keys pairs from your account.")

    cli_logger.doassert(os.path.exists(key_path), "Private key file " +
                        cf.bold("{}") + " not found for " + cf.bold("{}"),
                        key_path, key_name)  # todo: err msg
    assert os.path.exists(key_path), \
        "Private key file {} not found for {}".format(key_path, key_name)

    config["auth"]["ssh_private_key"] = key_path
    config["head_node"]["KeyName"] = key_name
    config["worker_nodes"]["KeyName"] = key_name

    return config
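The permission-safe key file creation used above, isolated as a standalone sketch; the path and key material are hypothetical:

import os
from functools import partial

key_path = "/tmp/example-key.pem"  # hypothetical path
# A custom opener makes open() create the file with 0o600 permissions from the
# start, instead of creating it world-readable and chmod-ing it afterwards.
with open(key_path, "w", opener=partial(os.open, mode=0o600)) as f:
    f.write("-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----\n")
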
Example #4
    def _create_node(self, node_config, tags, count):
        created_nodes_dict = {}

        tags = to_aws_format(tags)
        conf = node_config.copy()

        tag_pairs = [{
            "Key": TAG_RAY_CLUSTER_NAME,
            "Value": self.cluster_name,
        }]
        for k, v in tags.items():
            tag_pairs.append({
                "Key": k,
                "Value": v,
            })
        tag_specs = [{
            "ResourceType": "instance",
            "Tags": tag_pairs,
        }]
        user_tag_specs = conf.get("TagSpecifications", [])
        # Allow users to add tags and override values of existing
        # tags with their own. This only applies to the resource type
        # "instance". All other resource types are appended to the list of
        # tag specs.
        for user_tag_spec in user_tag_specs:
            if user_tag_spec["ResourceType"] == "instance":
                for user_tag in user_tag_spec["Tags"]:
                    exists = False
                    for tag in tag_specs[0]["Tags"]:
                        if user_tag["Key"] == tag["Key"]:
                            exists = True
                            tag["Value"] = user_tag["Value"]
                            break
                    if not exists:
                        tag_specs[0]["Tags"] += [user_tag]
            else:
                tag_specs += [user_tag_spec]

        # SubnetIds is not a real config key: we must resolve to a
        # single SubnetId before invoking the AWS API.
        subnet_ids = conf.pop("SubnetIds")

        for attempt in range(1, BOTO_CREATE_MAX_RETRIES + 1):
            try:
                subnet_id = subnet_ids[self.subnet_idx % len(subnet_ids)]

                self.subnet_idx += 1
                conf.update({
                    "MinCount":
                    1,
                    "MaxCount":
                    count,
                    "NetworkInterfaces": [{
                        "AssociatePublicIpAddress": True,
                        "DeviceIndex": 0,
                        "SubnetId": subnet_id,
                    }],
                    "TagSpecifications":
                    tag_specs
                })
                _ = conf.pop("SubnetId", None)
                security_group_ids = conf.pop("SecurityGroupIds", None)
                if security_group_ids is not None:
                    conf["NetworkInterfaces"][0]["Groups"] = security_group_ids

                created = self.ec2_fail_fast.create_instances(**conf)
                created_nodes_dict = {n.id: n for n in created}

                # todo: timed?
                # todo: handle plurality?
                with cli_logger.group("Launched {} nodes",
                                      count,
                                      _tags=dict(subnet_id=subnet_id)):
                    for instance in created:
                        # NOTE(maximsmol): This is needed for mocking
                        # boto3 for tests. This is likely a bug in moto
                        # but AWS docs don't seem to say.
                        # You can patch moto/ec2/responses/instances.py
                        # to fix this (add <stateReason> to EC2_RUN_INSTANCES)

                        # The correct value is technically
                        # {"code": "0", "Message": "pending"}
                        state_reason = instance.state_reason or {
                            "Message": "pending"
                        }

                        cli_logger.print("Launched instance {}",
                                         instance.instance_id,
                                         _tags=dict(
                                             state=instance.state["Name"],
                                             info=state_reason["Message"]))
                break
            except botocore.exceptions.ClientError as exc:
                if attempt == BOTO_CREATE_MAX_RETRIES:
                    cli_logger.abort(
                        "Failed to launch instances. Max attempts exceeded.",
                        exc=exc,
                    )
                else:
                    cli_logger.warning(
                        "create_instances: Attempt failed with {}, retrying.",
                        exc)
        return created_nodes_dict
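A small self-contained sketch of the tag-merging rule implemented above: user tags override same-keyed defaults for the "instance" resource type, while other resource types are appended. All tag keys and values below are hypothetical:

default_tag_specs = [{
    "ResourceType": "instance",
    "Tags": [{"Key": "ray-cluster-name", "Value": "demo"}],
}]
user_tag_specs = [
    {"ResourceType": "instance",
     "Tags": [{"Key": "ray-cluster-name", "Value": "my-override"},
              {"Key": "team", "Value": "ml"}]},
    {"ResourceType": "volume",
     "Tags": [{"Key": "team", "Value": "ml"}]},
]

for user_tag_spec in user_tag_specs:
    if user_tag_spec["ResourceType"] == "instance":
        for user_tag in user_tag_spec["Tags"]:
            for tag in default_tag_specs[0]["Tags"]:
                if user_tag["Key"] == tag["Key"]:
                    tag["Value"] = user_tag["Value"]  # user value wins
                    break
            else:
                default_tag_specs[0]["Tags"].append(user_tag)  # new key, append
    else:
        default_tag_specs.append(user_tag_spec)  # non-instance specs pass through

# default_tag_specs now holds the overridden cluster-name tag, the appended
# "team" tag, and the separate "volume" tag specification.
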
Example #5
    def _create_node(self, node_config, tags, count):
        created_nodes_dict = {}

        tags = to_aws_format(tags)
        conf = node_config.copy()

        tag_pairs = [{
            "Key": TAG_RAY_CLUSTER_NAME,
            "Value": self.cluster_name,
        }]
        for k, v in tags.items():
            tag_pairs.append({
                "Key": k,
                "Value": v,
            })
        if CloudwatchHelper.cloudwatch_config_exists(
                self.provider_config, CloudwatchConfigType.AGENT):
            cwa_installed = self._check_ami_cwa_installation(node_config)
            if cwa_installed:
                tag_pairs.extend([{
                    "Key": CLOUDWATCH_AGENT_INSTALLED_TAG,
                    "Value": "True",
                }])
        tag_specs = [{
            "ResourceType": "instance",
            "Tags": tag_pairs,
        }]
        user_tag_specs = conf.get("TagSpecifications", [])
        AWSNodeProvider._merge_tag_specs(tag_specs, user_tag_specs)

        # SubnetIds is not a real config key: we must resolve to a
        # single SubnetId before invoking the AWS API.
        subnet_ids = conf.pop("SubnetIds")

        # update config with min/max node counts and tag specs
        conf.update({
            "MinCount": 1,
            "MaxCount": count,
            "TagSpecifications": tag_specs
        })

        # Try to always launch in the first listed subnet.
        subnet_idx = 0
        cli_logger_tags = {}
        # NOTE: This ensures that we try ALL availability zones before
        # throwing an error.
        max_tries = max(BOTO_CREATE_MAX_RETRIES, len(subnet_ids))
        for attempt in range(1, max_tries + 1):
            try:
                if "NetworkInterfaces" in conf:
                    net_ifs = conf["NetworkInterfaces"]
                    # remove security group IDs previously copied from network
                    # interfaces (create_instances call fails otherwise)
                    conf.pop("SecurityGroupIds", None)
                    cli_logger_tags["network_interfaces"] = str(net_ifs)
                else:
                    subnet_id = subnet_ids[subnet_idx % len(subnet_ids)]
                    conf["SubnetId"] = subnet_id
                    cli_logger_tags["subnet_id"] = subnet_id

                created = self.ec2_fail_fast.create_instances(**conf)
                created_nodes_dict = {n.id: n for n in created}

                # todo: timed?
                # todo: handle plurality?
                with cli_logger.group("Launched {} nodes",
                                      count,
                                      _tags=cli_logger_tags):
                    for instance in created:
                        # NOTE(maximsmol): This is needed for mocking
                        # boto3 for tests. This is likely a bug in moto
                        # but AWS docs don't seem to say.
                        # You can patch moto/ec2/responses/instances.py
                        # to fix this (add <stateReason> to EC2_RUN_INSTANCES)

                        # The correct value is technically
                        # {"code": "0", "Message": "pending"}
                        state_reason = instance.state_reason or {
                            "Message": "pending"
                        }

                        cli_logger.print("Launched instance {}",
                                         instance.instance_id,
                                         _tags=dict(
                                             state=instance.state["Name"],
                                             info=state_reason["Message"]))
                break
            except botocore.exceptions.ClientError as exc:
                if attempt == max_tries:
                    cli_logger.abort(
                        "Failed to launch instances. Max attempts exceeded.",
                        exc=exc,
                    )
                else:
                    cli_logger.warning(
                        "create_instances: Attempt failed with {}, retrying.",
                        exc)

                # Launch failure may be due to instance type availability in
                # the given AZ
                subnet_idx += 1

        return created_nodes_dict
Example #6
    def run(self):
        cli_logger.old_info(logger, "{}Updating to {}", self.log_prefix,
                            self.runtime_hash)

        if cmd_output_util.does_allow_interactive(
        ) and cmd_output_util.is_output_redirected():
            # this is most probably a bug since the user has no control
            # over these settings
            msg = ("Output was redirected for an interactive command. "
                   "Either do not pass `--redirect-command-output` "
                   "or also pass in `--use-normal-shells`.")
            cli_logger.abort(msg)
            raise click.ClickException(msg)

        try:
            with LogTimer(self.log_prefix +
                          "Applied config {}".format(self.runtime_hash)):
                self.do_update()
        except Exception as e:
            error_str = str(e)
            if hasattr(e, "cmd"):
                error_str = "(Exit Status {}) {}".format(
                    e.returncode, " ".join(e.cmd))

            self.provider.set_node_tags(
                self.node_id, {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
            cli_logger.error("New status: {}", cf.bold(STATUS_UPDATE_FAILED))

            cli_logger.old_error(logger, "{}Error executing: {}\n",
                                 self.log_prefix, error_str)

            cli_logger.error("!!!")
            if hasattr(e, "cmd"):
                cli_logger.error(
                    "Setup command `{}` failed with exit code {}. stderr:",
                    cf.bold(e.cmd), e.returncode)
            else:
                cli_logger.verbose_error("{}", str(vars(e)))
                # todo: handle this better somehow?
                cli_logger.error("{}", str(e))
            # todo: print stderr here
            cli_logger.error("!!!")
            cli_logger.newline()

            if isinstance(e, click.ClickException):
                # todo: why do we ignore this here
                return
            raise

        tags_to_set = {
            TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
            TAG_RAY_RUNTIME_CONFIG: self.runtime_hash,
        }
        if self.file_mounts_contents_hash is not None:
            tags_to_set[
                TAG_RAY_FILE_MOUNTS_CONTENTS] = self.file_mounts_contents_hash

        self.provider.set_node_tags(self.node_id, tags_to_set)
        cli_logger.labeled_value("New status", STATUS_UP_TO_DATE)

        self.exitcode = 0
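The error-formatting branch above relies on subprocess errors carrying `cmd` and `returncode`. A standalone sketch of that pattern (the failing command is hypothetical):

import subprocess

try:
    subprocess.run(["false"], check=True)  # hypothetical failing command
except Exception as e:
    error_str = str(e)
    if hasattr(e, "cmd"):
        # CalledProcessError exposes the command and its exit status.
        error_str = "(Exit Status {}) {}".format(e.returncode, " ".join(e.cmd))
    print(error_str)
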
Example #7
def start(node_ip_address, address, port, redis_password, redis_shard_ports,
          object_manager_port, node_manager_port, gcs_server_port,
          min_worker_port, max_worker_port, worker_port_list, memory,
          object_store_memory, redis_max_memory, num_cpus, num_gpus, resources,
          head, include_dashboard, dashboard_host, dashboard_port, block,
          plasma_directory, autoscaling_config, no_redirect_worker_output,
          no_redirect_output, plasma_store_socket_name, raylet_socket_name,
          temp_dir, java_worker_options, load_code_from_local,
          code_search_path, system_config, lru_evict,
          enable_object_reconstruction, metrics_export_port, log_style,
          log_color, verbose):
    """Start Ray processes manually on the local machine."""
    cli_logger.configure(log_style, log_color, verbose)
    if gcs_server_port and not head:
        raise ValueError(
            "gcs_server_port can be only assigned when you specify --head.")

    # Convert hostnames to numerical IP address.
    if node_ip_address is not None:
        node_ip_address = services.address_to_ip(node_ip_address)

    redis_address = None
    if address is not None:
        (redis_address, redis_address_ip,
         redis_address_port) = services.validate_redis_address(address)
    try:
        resources = json.loads(resources)
    except Exception:
        cli_logger.error("`{}` is not a valid JSON string.",
                         cf.bold("--resources"))
        cli_logger.abort(
            "Valid values look like this: `{}`",
            cf.bold("--resources='\"CustomResource3\": 1, "
                    "\"CustomResource2\": 2}'"))

        raise Exception("Unable to parse the --resources argument using "
                        "json.loads. Try using a format like\n\n"
                        "    --resources='{\"CustomResource1\": 3, "
                        "\"CustomReseource2\": 2}'")

    redirect_worker_output = None if not no_redirect_worker_output else True
    redirect_output = None if not no_redirect_output else True
    ray_params = ray.parameter.RayParams(
        node_ip_address=node_ip_address,
        min_worker_port=min_worker_port,
        max_worker_port=max_worker_port,
        worker_port_list=worker_port_list,
        object_manager_port=object_manager_port,
        node_manager_port=node_manager_port,
        gcs_server_port=gcs_server_port,
        memory=memory,
        object_store_memory=object_store_memory,
        redis_password=redis_password,
        redirect_worker_output=redirect_worker_output,
        redirect_output=redirect_output,
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        resources=resources,
        plasma_directory=plasma_directory,
        huge_pages=False,
        plasma_store_socket_name=plasma_store_socket_name,
        raylet_socket_name=raylet_socket_name,
        temp_dir=temp_dir,
        include_dashboard=include_dashboard,
        dashboard_host=dashboard_host,
        dashboard_port=dashboard_port,
        java_worker_options=java_worker_options,
        load_code_from_local=load_code_from_local,
        code_search_path=code_search_path,
        _system_config=system_config,
        lru_evict=lru_evict,
        enable_object_reconstruction=enable_object_reconstruction,
        metrics_export_port=metrics_export_port)
    if head:
        # Use default if port is none, allocate an available port if port is 0
        if port is None:
            port = ray_constants.DEFAULT_PORT

        if port == 0:
            with socket() as s:
                s.bind(("", 0))
                port = s.getsockname()[1]

        num_redis_shards = None
        # Start Ray on the head node.
        if redis_shard_ports is not None:
            redis_shard_ports = redis_shard_ports.split(",")
            # Infer the number of Redis shards from the ports if the number is
            # not provided.
            num_redis_shards = len(redis_shard_ports)

        if redis_address is not None:
            cli_logger.abort(
                "`{}` starts a new Redis server, `{}` should not be set.",
                cf.bold("--head"), cf.bold("--address"))

            raise Exception("If --head is passed in, a Redis server will be "
                            "started, so a Redis address should not be "
                            "provided.")

        node_ip_address = services.get_node_ip_address()

        # Get the node IP address if one is not provided.
        ray_params.update_if_absent(node_ip_address=node_ip_address)
        cli_logger.labeled_value("Local node IP", ray_params.node_ip_address)
        ray_params.update_if_absent(
            redis_port=port,
            redis_shard_ports=redis_shard_ports,
            redis_max_memory=redis_max_memory,
            num_redis_shards=num_redis_shards,
            redis_max_clients=None,
            autoscaling_config=autoscaling_config,
        )

        # Fail early when starting a new cluster when one is already running
        if address is None:
            default_address = f"{node_ip_address}:{port}"
            redis_addresses = services.find_redis_address(default_address)
            if len(redis_addresses) > 0:
                raise ConnectionError(
                    f"Ray is already running at {default_address}. "
                    f"Please specify a different port using the `--port`"
                    f" command to `ray start`.")

        node = ray.node.Node(
            ray_params, head=True, shutdown_at_exit=block, spawn_reaper=block)
        redis_address = node.redis_address

        # this is a noop if new-style is not set, so the old logger calls
        # are still in place
        cli_logger.newline()
        startup_msg = "Ray runtime started."
        cli_logger.success("-" * len(startup_msg))
        cli_logger.success(startup_msg)
        cli_logger.success("-" * len(startup_msg))
        cli_logger.newline()
        with cli_logger.group("Next steps"):
            cli_logger.print(
                "To connect to this Ray runtime from another node, run")
            cli_logger.print(
                cf.bold("  ray start --address='{}'{}"), redis_address,
                f" --redis-password='******'"
                if redis_password else "")
            cli_logger.newline()
            cli_logger.print("Alternatively, use the following Python code:")
            with cli_logger.indented():
                with cf.with_style("monokai") as c:
                    cli_logger.print("{} ray", c.magenta("import"))
                    cli_logger.print(
                        "ray{}init(address{}{}{})", c.magenta("."),
                        c.magenta("="), c.yellow("'auto'"),
                        ", _redis_password{}{}".format(
                            c.magenta("="),
                            c.yellow("'" + redis_password + "'"))
                        if redis_password else "")
            cli_logger.newline()
            cli_logger.print(
                cf.underlined("If connection fails, check your "
                              "firewall settings and "
                              "network configuration."))
            cli_logger.newline()
            cli_logger.print("To terminate the Ray runtime, run")
            cli_logger.print(cf.bold("  ray stop"))
    else:
        # Start Ray on a non-head node.
        if port is not None:
            cli_logger.abort("`{}` should not be specified without `{}`.",
                             cf.bold("--port"), cf.bold("--head"))

            raise Exception("If --head is not passed in, --port is not "
                            "allowed.")
        if redis_shard_ports is not None:
            cli_logger.abort("`{}` should not be specified without `{}`.",
                             cf.bold("--redis-shard-ports"), cf.bold("--head"))

            raise Exception("If --head is not passed in, --redis-shard-ports "
                            "is not allowed.")
        if redis_address is None:
            cli_logger.abort("`{}` is required unless starting with `{}`.",
                             cf.bold("--address"), cf.bold("--head"))

            raise Exception("If --head is not passed in, --address must "
                            "be provided.")
        if include_dashboard:
            cli_logger.abort("`{}` should not be specified without `{}`.",
                             cf.bold("--include-dashboard"), cf.bold("--head"))

            raise ValueError(
                "If --head is not passed in, the --include-dashboard"
                "flag is not relevant.")

        # Wait for the Redis server to be started. And throw an exception if we
        # can't connect to it.
        services.wait_for_redis_to_start(
            redis_address_ip, redis_address_port, password=redis_password)

        # Create a Redis client.
        redis_client = services.create_redis_client(
            redis_address, password=redis_password)

        # Check that the version information on this node matches the version
        # information that the cluster was started with.
        services.check_version_info(redis_client)

        # Get the node IP address if one is not provided.
        ray_params.update_if_absent(
            node_ip_address=services.get_node_ip_address(redis_address))

        cli_logger.labeled_value("Local node IP", ray_params.node_ip_address)

        # Check that there aren't already Redis clients with the same IP
        # address connected with this Redis instance. This raises an exception
        # if the Redis server already has clients on this node.
        check_no_existing_redis_clients(ray_params.node_ip_address,
                                        redis_client)
        ray_params.update(redis_address=redis_address)
        node = ray.node.Node(
            ray_params, head=False, shutdown_at_exit=block, spawn_reaper=block)

        cli_logger.newline()
        startup_msg = "Ray runtime started."
        cli_logger.success("-" * len(startup_msg))
        cli_logger.success(startup_msg)
        cli_logger.success("-" * len(startup_msg))
        cli_logger.newline()
        cli_logger.print("To terminate the Ray runtime, run")
        cli_logger.print(cf.bold("  ray stop"))

    if block:
        cli_logger.newline()
        with cli_logger.group(cf.bold("--block")):
            cli_logger.print(
                "This command will now block until terminated by a signal.")
            cli_logger.print(
                "Runing subprocesses are monitored and a message will be "
                "printed if any of them terminate unexpectedly.")

        while True:
            time.sleep(1)
            deceased = node.dead_processes()
            if len(deceased) > 0:
                cli_logger.newline()
                cli_logger.error("Some Ray subprcesses exited unexpectedly:")

                with cli_logger.indented():
                    for process_type, process in deceased:
                        cli_logger.error(
                            "{}",
                            cf.bold(str(process_type)),
                            _tags={"exit code": str(process.returncode)})

                # shutdown_at_exit will handle cleanup.
                cli_logger.newline()
                cli_logger.error("Remaining processes will be killed.")
                sys.exit(1)
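The bind-to-port-0 idiom used above (when `--head` is given with `port == 0`) as a standalone sketch:

from socket import socket

# Binding to port 0 asks the OS for any free ephemeral port; reading the
# socket name back before the socket closes reveals which port was chosen.
with socket() as s:
    s.bind(("", 0))
    free_port = s.getsockname()[1]
print("allocated port:", free_port)
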
Example #8
def _configure_subnet(config):
    ec2 = _resource("ec2", config)
    use_internal_ips = config["provider"].get("use_internal_ips", False)

    # If head or worker security group is specified, filter down to subnets
    # belonging to the same VPC as the security group.
    sg_ids = []
    for node_type in config["available_node_types"].values():
        node_config = node_type["node_config"]
        sg_ids.extend(node_config.get("SecurityGroupIds", []))
    if sg_ids:
        vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
    else:
        vpc_id_of_sg = None

    try:
        candidate_subnets = ec2.subnets.all()
        if vpc_id_of_sg:
            candidate_subnets = [
                s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
            ]
        subnets = sorted(
            (s for s in candidate_subnets if s.state == "available" and (
                use_internal_ips or s.map_public_ip_on_launch)),
            reverse=True,  # sort from Z-A
            key=lambda subnet: subnet.availability_zone,
        )
    except botocore.exceptions.ClientError as exc:
        handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
        raise exc

    if not subnets:
        cli_logger.abort(
            "No usable subnets found, try manually creating an instance in "
            "your specified region to populate the list of subnets "
            "and trying this again.\n"
            "Note that the subnet must map public IPs "
            "on instance launch unless you set `use_internal_ips: true` in "
            "the `provider` config.")

    if "availability_zone" in config["provider"]:
        azs = config["provider"]["availability_zone"].split(",")
        subnets = [
            s for az in azs  # Iterate over AZs first to maintain the ordering
            for s in subnets if s.availability_zone == az
        ]
        if not subnets:
            cli_logger.abort(
                "No usable subnets matching availability zone {} found.\n"
                "Choose a different availability zone or try "
                "manually creating an instance in your specified region "
                "to populate the list of subnets and trying this again.",
                config["provider"]["availability_zone"],
            )

    # Use subnets in only one VPC, so that _configure_security_groups only
    # needs to create a security group in this one VPC. Otherwise, we'd need
    # to set up security groups in all of the user's VPCs and set up networking
    # rules to allow traffic between these groups.
    # See https://github.com/ray-project/ray/pull/14868.
    subnet_ids = [
        s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id
    ]
    # map from node type key -> source of SubnetIds field
    subnet_src_info = {}
    _set_config_info(subnet_src=subnet_src_info)
    for key, node_type in config["available_node_types"].items():
        node_config = node_type["node_config"]
        if "SubnetIds" not in node_config:
            subnet_src_info[key] = "default"
            node_config["SubnetIds"] = subnet_ids
        else:
            subnet_src_info[key] = "config"

    return config
Example #9
def rsync(config_file: str,
          source: Optional[str],
          target: Optional[str],
          override_cluster_name: Optional[str],
          down: bool,
          ip_address: Optional[str] = None,
          use_internal_ip: bool = False,
          no_config_cache: bool = False,
          all_nodes: bool = False,
          _runner: ModuleType = subprocess) -> None:
    """Rsyncs files.

    Arguments:
        config_file: path to the cluster yaml
        source: source dir
        target: target dir
        override_cluster_name: set the name of the cluster
        down: whether we're syncing remote -> local
        ip_address (str): Address of the node to sync to. An exception is
            raised if both ip_address and 'all_nodes' are provided.
        use_internal_ip (bool): Whether the provided ip_address is an
            internal (private) IP rather than a public one.
        all_nodes: whether to sync worker nodes in addition to the head node
    """
    if bool(source) != bool(target):
        cli_logger.abort(
            "Expected either both a source and a target, or neither.")

    assert bool(source) == bool(target), (
        "Must either provide both or neither source and target.")

    if ip_address and all_nodes:
        cli_logger.abort("Cannot provide both ip_address and 'all_nodes'.")

    config = yaml.safe_load(open(config_file).read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = _bootstrap_config(config, no_config_cache=no_config_cache)

    is_file_mount = False
    if source and target:
        for remote_mount in config.get("file_mounts", {}).keys():
            if (source if down else target).startswith(remote_mount):
                is_file_mount = True
                break

    provider = _get_node_provider(config["provider"], config["cluster_name"])

    def rsync_to_node(node_id, is_head_node):
        updater = NodeUpdaterThread(node_id=node_id,
                                    provider_config=config["provider"],
                                    provider=provider,
                                    auth_config=config["auth"],
                                    cluster_name=config["cluster_name"],
                                    file_mounts=config["file_mounts"],
                                    initialization_commands=[],
                                    setup_commands=[],
                                    ray_start_commands=[],
                                    runtime_hash="",
                                    use_internal_ip=use_internal_ip,
                                    process_runner=_runner,
                                    file_mounts_contents_hash="",
                                    is_head_node=is_head_node,
                                    docker_config=config.get("docker"))
        if down:
            rsync = updater.rsync_down
        else:
            rsync = updater.rsync_up

        if source and target:
            # print rsync progress for single file rsync
            cmd_output_util.set_output_redirected(False)
            set_rsync_silent(False)
            rsync(source, target, is_file_mount)
        else:
            updater.sync_file_mounts(rsync)

    try:
        nodes = []
        head_node = _get_head_node(config,
                                   config_file,
                                   override_cluster_name,
                                   create_if_needed=False)
        if ip_address:
            nodes = [
                provider.get_node_id(ip_address,
                                     use_internal_ip=use_internal_ip)
            ]
        else:
            if all_nodes:
                nodes = _get_worker_nodes(config, override_cluster_name)
            nodes += [head_node]

        for node_id in nodes:
            rsync_to_node(node_id, is_head_node=(node_id == head_node))

    finally:
        provider.cleanup()
Example #10
def get_or_create_head_node(config: Dict[str, Any],
                            config_file: str,
                            no_restart: bool,
                            restart_only: bool,
                            yes: bool,
                            override_cluster_name: Optional[str],
                            _provider: Optional[NodeProvider] = None,
                            _runner: ModuleType = subprocess) -> None:
    """Create the cluster head node, which in turn creates the workers."""
    provider = (_provider or _get_node_provider(config["provider"],
                                                config["cluster_name"]))

    config = copy.deepcopy(config)
    config_file = os.path.abspath(config_file)
    try:
        head_node_tags = {
            TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
        }
        nodes = provider.non_terminated_nodes(head_node_tags)
        if len(nodes) > 0:
            head_node = nodes[0]
        else:
            head_node = None

        if not head_node:
            cli_logger.confirm(yes, "No head node found. "
                               "Launching a new cluster.",
                               _abort=True)
            cli_logger.old_confirm("This will create a new cluster", yes)
        elif not no_restart:
            cli_logger.old_confirm("This will restart cluster services", yes)

        if head_node:
            if restart_only:
                cli_logger.confirm(
                    yes, "Updating cluster configuration and "
                    "restarting the cluster Ray runtime. "
                    "Setup commands will not be run due to `{}`.\n",
                    cf.bold("--restart-only"),
                    _abort=True)
            elif no_restart:
                cli_logger.print(
                    "Cluster Ray runtime will not be restarted due "
                    "to `{}`.", cf.bold("--no-restart"))
                cli_logger.confirm(yes, "Updating cluster configuration and "
                                   "running setup commands.",
                                   _abort=True)
            else:
                cli_logger.print(
                    "Updating cluster configuration and running full setup.")
                cli_logger.confirm(
                    yes,
                    cf.bold("Cluster Ray runtime will be restarted."),
                    _abort=True)
        cli_logger.newline()

        # TODO(ekl) this logic is duplicated in node_launcher.py (keep in sync)
        head_node_config = copy.deepcopy(config["head_node"])
        if "head_node_type" in config:
            head_node_tags[TAG_RAY_USER_NODE_TYPE] = config["head_node_type"]
            head_node_config.update(config["available_node_types"][
                config["head_node_type"]]["node_config"])

        launch_hash = hash_launch_conf(head_node_config, config["auth"])
        if head_node is None or provider.node_tags(head_node).get(
                TAG_RAY_LAUNCH_CONFIG) != launch_hash:
            with cli_logger.group("Acquiring an up-to-date head node"):
                if head_node is not None:
                    cli_logger.print(
                        "Currently running head node is out-of-date with "
                        "cluster configuration")
                    cli_logger.print(
                        "hash is {}, expected {}",
                        cf.bold(
                            provider.node_tags(head_node).get(
                                TAG_RAY_LAUNCH_CONFIG)), cf.bold(launch_hash))
                    cli_logger.confirm(yes, "Relaunching it.", _abort=True)
                    cli_logger.old_confirm(
                        "Head node config out-of-date. It will be terminated",
                        yes)

                    cli_logger.old_info(
                        logger, "get_or_create_head_node: "
                        "Shutting down outdated head node {}", head_node)

                    provider.terminate_node(head_node)
                    cli_logger.print("Terminated head node {}", head_node)

                cli_logger.old_info(
                    logger,
                    "get_or_create_head_node: Launching new head node...")

                head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
                head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
                    config["cluster_name"])
                provider.create_node(head_node_config, head_node_tags, 1)
                cli_logger.print("Launched a new head node")

                start = time.time()
                head_node = None
                with cli_logger.group("Fetching the new head node"):
                    while True:
                        if time.time() - start > 50:
                            cli_logger.abort(
                                "Head node fetch timed out.")  # todo: msg
                            raise RuntimeError("Failed to create head node.")
                        nodes = provider.non_terminated_nodes(head_node_tags)
                        if len(nodes) == 1:
                            head_node = nodes[0]
                            break
                        time.sleep(POLL_INTERVAL)
                cli_logger.newline()

        with cli_logger.group(
                "Setting up head node",
                _numbered=("<>", 1, 1),
                # cf.bold(provider.node_tags(head_node)[TAG_RAY_NODE_NAME]),
                _tags=dict()):  # add id, ARN to tags?

            # TODO(ekl) right now we always update the head node even if the
            # hash matches.
            # We could prompt the user for what they want to do here.
            # No need to pass in cluster_sync_files because we use this
            # hash to set up the head node
            (runtime_hash, file_mounts_contents_hash) = hash_runtime_conf(
                config["file_mounts"], None, config)

            cli_logger.old_info(
                logger,
                "get_or_create_head_node: Updating files on head node...")

            # Rewrite the auth config so that the head
            # node can update the workers
            remote_config = copy.deepcopy(config)

            # drop proxy options if they exist, otherwise
            # head node won't be able to connect to workers
            remote_config["auth"].pop("ssh_proxy_command", None)

            if "ssh_private_key" in config["auth"]:
                remote_key_path = "~/ray_bootstrap_key.pem"
                remote_config["auth"]["ssh_private_key"] = remote_key_path

            # Adjust for new file locations
            new_mounts = {}
            for remote_path in config["file_mounts"]:
                new_mounts[remote_path] = remote_path
            remote_config["file_mounts"] = new_mounts
            remote_config["no_restart"] = no_restart

            remote_config = provider.prepare_for_head_node(remote_config)

            # Now inject the rewritten config and SSH key into the head node
            remote_config_file = tempfile.NamedTemporaryFile(
                "w", prefix="ray-bootstrap-")
            remote_config_file.write(json.dumps(remote_config))
            remote_config_file.flush()
            config["file_mounts"].update(
                {"~/ray_bootstrap_config.yaml": remote_config_file.name})

            if "ssh_private_key" in config["auth"]:
                config["file_mounts"].update({
                    remote_key_path:
                    config["auth"]["ssh_private_key"],
                })
            cli_logger.print("Prepared bootstrap config")

            if restart_only:
                setup_commands = []
                ray_start_commands = config["head_start_ray_commands"]
            elif no_restart:
                setup_commands = config["head_setup_commands"]
                ray_start_commands = []
            else:
                setup_commands = config["head_setup_commands"]
                ray_start_commands = config["head_start_ray_commands"]

            if not no_restart:
                warn_about_bad_start_command(ray_start_commands)

            updater = NodeUpdaterThread(
                node_id=head_node,
                provider_config=config["provider"],
                provider=provider,
                auth_config=config["auth"],
                cluster_name=config["cluster_name"],
                file_mounts=config["file_mounts"],
                initialization_commands=config["initialization_commands"],
                setup_commands=setup_commands,
                ray_start_commands=ray_start_commands,
                process_runner=_runner,
                runtime_hash=runtime_hash,
                file_mounts_contents_hash=file_mounts_contents_hash,
                is_head_node=True,
                docker_config=config.get("docker"))
            updater.start()
            updater.join()

            # Refresh the node cache so we see the external ip if available
            provider.non_terminated_nodes(head_node_tags)

            if config.get("provider", {}).get("use_internal_ips",
                                              False) is True:
                head_node_ip = provider.internal_ip(head_node)
            else:
                head_node_ip = provider.external_ip(head_node)

            if updater.exitcode != 0:
                # todo: this does not follow the mockup and is not good enough
                cli_logger.abort("Failed to setup head node.")

                cli_logger.old_error(
                    logger, "get_or_create_head_node: "
                    "Updating {} failed", head_node_ip)
                sys.exit(1)

            cli_logger.old_info(
                logger, "get_or_create_head_node: "
                "Head node up-to-date, IP address is: {}", head_node_ip)

        monitor_str = "tail -n 100 -f /tmp/ray/session_latest/logs/monitor*"
        if override_cluster_name:
            modifiers = " --cluster-name={}".format(
                quote(override_cluster_name))
        else:
            modifiers = ""

        if cli_logger.old_style:
            print("To monitor autoscaling activity, you can run:\n\n"
                  "  ray exec {} {}{}\n".format(config_file,
                                                quote(monitor_str), modifiers))
            print("To open a console on the cluster:\n\n"
                  "  ray attach {}{}\n".format(config_file, modifiers))

            print("To get a remote shell to the cluster manually, run:\n\n"
                  "  {}\n".format(
                      updater.cmd_runner.remote_shell_command_str()))

        cli_logger.newline()
        with cli_logger.group("Useful commands"):
            cli_logger.print("Monitor autoscaling with")
            cli_logger.print(cf.bold("  ray exec {}{} {}"), config_file,
                             modifiers, quote(monitor_str))

            cli_logger.print("Connect to a terminal on the cluster head:")
            cli_logger.print(cf.bold("  ray attach {}{}"), config_file,
                             modifiers)

            remote_shell_str = updater.cmd_runner.remote_shell_command_str()
            cli_logger.print("Get a remote shell to the cluster manually:")
            cli_logger.print("  {}", remote_shell_str.strip())
    finally:
        provider.cleanup()
Example #11
def create_or_update_cluster(config_file: str,
                             override_min_workers: Optional[int],
                             override_max_workers: Optional[int],
                             no_restart: bool,
                             restart_only: bool,
                             yes: bool,
                             override_cluster_name: Optional[str] = None,
                             no_config_cache: bool = False,
                             redirect_command_output: Optional[bool] = False,
                             use_login_shells: bool = True) -> None:
    """Create or updates an autoscaling Ray cluster from a config json."""
    set_using_login_shells(use_login_shells)
    if not use_login_shells:
        cmd_output_util.set_allow_interactive(False)
    if redirect_command_output is None:
        # Do not redirect by default.
        cmd_output_util.set_output_redirected(False)
    else:
        cmd_output_util.set_output_redirected(redirect_command_output)

    def handle_yaml_error(e):
        cli_logger.error("Cluster config invalid")
        cli_logger.newline()
        cli_logger.error("Failed to load YAML file " + cf.bold("{}"),
                         config_file)
        cli_logger.newline()
        with cli_logger.verbatim_error_ctx("PyYAML error:"):
            cli_logger.error(e)
        cli_logger.abort()

    try:
        config = yaml.safe_load(open(config_file).read())
    except FileNotFoundError:
        cli_logger.abort(
            "Provided cluster configuration file ({}) does not exist",
            cf.bold(config_file))
        raise
    except yaml.parser.ParserError as e:
        handle_yaml_error(e)
        raise
    except yaml.scanner.ScannerError as e:
        handle_yaml_error(e)
        raise

    # todo: validate file_mounts, ssh keys, etc.

    importer = _NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        cli_logger.abort(
            "Unknown provider type " + cf.bold("{}") + "\n"
            "Available providers are: {}", config["provider"]["type"],
            cli_logger.render_list([
                k for k in _NODE_PROVIDERS.keys()
                if _NODE_PROVIDERS[k] is not None
            ]))
        raise NotImplementedError("Unsupported provider {}".format(
            config["provider"]))

    printed_overrides = False

    def handle_cli_override(key, override):
        if override is not None:
            if key in config:
                nonlocal printed_overrides
                printed_overrides = True
                cli_logger.warning(
                    "`{}` override provided on the command line.\n"
                    "  Using " + cf.bold("{}") +
                    cf.dimmed(" [configuration file has " + cf.bold("{}") +
                              "]"), key, override, config[key])
            config[key] = override

    handle_cli_override("min_workers", override_min_workers)
    handle_cli_override("max_workers", override_max_workers)
    handle_cli_override("cluster_name", override_cluster_name)

    if printed_overrides:
        cli_logger.newline()

    cli_logger.labeled_value("Cluster", config["cluster_name"])

    # disable the cli_logger here if needed
    # because it only supports aws
    if config["provider"]["type"] != "aws":
        cli_logger.old_style = True
    cli_logger.newline()
    config = _bootstrap_config(config, no_config_cache=no_config_cache)

    try_logging_config(config)
    get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
                            override_cluster_name)
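
A minimal usage sketch for the function above, assuming the surrounding Ray autoscaler imports are available; the config path is hypothetical and the arguments follow the signature shown.

create_or_update_cluster(
    config_file="cluster.yaml",      # hypothetical path to a cluster config
    override_min_workers=None,       # keep min_workers from the config file
    override_max_workers=None,       # keep max_workers from the config file
    no_restart=False,
    restart_only=False,
    yes=True,                        # answer confirmation prompts automatically
    override_cluster_name=None,
)
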
Example No. 12
0
def _usable_subnet_ids(
    user_specified_subnets: Optional[List[Any]],
    all_subnets: List[Any],
    azs: Optional[str],
    vpc_id_of_sg: Optional[str],
    use_internal_ips: bool,
    node_type_key: str,
) -> Tuple[List[str], str]:
    """Prunes subnets down to those that meet the following criteria.

    Subnets must be:
    * 'Available' according to AWS.
    * Public, unless `use_internal_ips` is specified.
    * In one of the AZs, if AZs are provided.
    * In the given VPC, if a VPC is specified for Security Groups.

    Returns:
        List[str]: Subnets that are usable.
        str: VPC ID of the first subnet.
    """
    def _are_user_subnets_pruned(current_subnets: List[Any]) -> bool:
        return user_specified_subnets is not None and len(
            current_subnets) != len(user_specified_subnets)

    def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]:
        current_subnet_ids = {s.subnet_id for s in current_subnets}
        user_specified_subnet_ids = {
            s.subnet_id
            for s in user_specified_subnets
        }
        return user_specified_subnet_ids - current_subnet_ids

    try:
        candidate_subnets = (user_specified_subnets if user_specified_subnets
                             is not None else all_subnets)
        if vpc_id_of_sg:
            candidate_subnets = [
                s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
            ]
        subnets = sorted(
            (s for s in candidate_subnets if s.state == "available" and (
                use_internal_ips or s.map_public_ip_on_launch)),
            reverse=True,  # sort from Z-A
            key=lambda subnet: subnet.availability_zone,
        )
    except botocore.exceptions.ClientError as exc:
        handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
        raise exc

    if not subnets:
        cli_logger.abort(
            f"No usable subnets found for node type {node_type_key}, try "
            "manually creating an instance in your specified region to "
            "populate the list of subnets and trying this again.\n"
            "Note that the subnet must map public IPs "
            "on instance launch unless you set `use_internal_ips: true` in "
            "the `provider` config.")
    elif _are_user_subnets_pruned(subnets):
        cli_logger.abort(
            f"The specified subnets for node type {node_type_key} are not "
            f"usable: {_get_pruned_subnets(subnets)}")

    if azs is not None:
        azs = [az.strip() for az in azs.split(",")]
        subnets = [
            s for az in azs  # Iterate over AZs first to maintain the ordering
            for s in subnets if s.availability_zone == az
        ]
        if not subnets:
            cli_logger.abort(
                f"No usable subnets matching availability zone {azs} found "
                f"for node type {node_type_key}.\nChoose a different "
                "availability zone or try manually creating an instance in "
                "your specified region to populate the list of subnets and "
                "trying this again.")
        elif _are_user_subnets_pruned(subnets):
            cli_logger.abort(
                f"MISMATCH between specified subnets and Availability Zones! "
                "The following Availability Zones were specified in the "
                f"`provider section`: {azs}.\n The following subnets for node "
                f"type `{node_type_key}` have no matching availability zone: "
                f"{list(_get_pruned_subnets(subnets))}.")

    # Use subnets in only one VPC, so that _configure_security_groups only
    # needs to create a security group in this one VPC. Otherwise, we'd need
    # to set up security groups in all of the user's VPCs and set up networking
    # rules to allow traffic between these groups.
    # See https://github.com/ray-project/ray/pull/14868.
    first_subnet_vpc_id = subnets[0].vpc_id
    subnets = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
    if _are_user_subnets_pruned(subnets):
        subnet_vpcs = {s.subnet_id: s.vpc_id for s in user_specified_subnets}
        cli_logger.abort(
            f"Subnets specified in more than one VPC for node type `{node_type_key}`! "
            f"Please ensure that all subnets share the same VPC and retry your "
            "request. Subnet VPCs: {}",
            subnet_vpcs,
        )
    return subnets, first_subnet_vpc_id
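
The stand-in objects below (not real boto3 Subnet resources) illustrate the core filter used above: keep only 'available' subnets that map public IPs on launch unless `use_internal_ips` is set, sorted Z-A by availability zone.

from dataclasses import dataclass

@dataclass
class StubSubnet:          # stand-in for a boto3 ec2.Subnet resource
    subnet_id: str
    vpc_id: str
    state: str
    availability_zone: str
    map_public_ip_on_launch: bool

all_subnets = [
    StubSubnet("subnet-a", "vpc-1", "available", "us-west-2a", True),
    StubSubnet("subnet-b", "vpc-1", "available", "us-west-2b", False),
    StubSubnet("subnet-c", "vpc-2", "pending", "us-west-2c", True),
]
use_internal_ips = False

usable = sorted(
    (s for s in all_subnets if s.state == "available" and (
        use_internal_ips or s.map_public_ip_on_launch)),
    reverse=True,  # sort from Z-A
    key=lambda subnet: subnet.availability_zone,
)
print([s.subnet_id for s in usable])  # ['subnet-a']: the only available, public subnet
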
Example No. 13
0
def _configure_subnet(config):
    ec2 = _resource("ec2", config)
    use_internal_ips = config["provider"].get("use_internal_ips", False)

    # If head or worker security group is specified, filter down to subnets
    # belonging to the same VPC as the security group.
    sg_ids = (config["head_node"].get("SecurityGroupIds", []) +
              config["worker_nodes"].get("SecurityGroupIds", []))
    if sg_ids:
        vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
    else:
        vpc_id_of_sg = None

    try:
        candidate_subnets = ec2.subnets.all()
        if vpc_id_of_sg:
            candidate_subnets = [
                s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
            ]
        subnets = sorted(
            (s for s in candidate_subnets if s.state == "available" and (
                use_internal_ips or s.map_public_ip_on_launch)),
            reverse=True,  # sort from Z-A
            key=lambda subnet: subnet.availability_zone)
    except botocore.exceptions.ClientError as exc:
        handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
        raise exc

    if not subnets:
        cli_logger.abort(
            "No usable subnets found, try manually creating an instance in "
            "your specified region to populate the list of subnets "
            "and trying this again.\n"
            "Note that the subnet must map public IPs "
            "on instance launch unless you set `use_internal_ips: true` in "
            "the `provider` config.")  # todo: err msg
        raise Exception(
            "No usable subnets found, try manually creating an instance in "
            "your specified region to populate the list of subnets "
            "and trying this again. Note that the subnet must map public IPs "
            "on instance launch unless you set 'use_internal_ips': True in "
            "the 'provider' config.")
    if "availability_zone" in config["provider"]:
        azs = config["provider"]["availability_zone"].split(",")
        subnets = [s for s in subnets if s.availability_zone in azs]
        if not subnets:
            cli_logger.abort(
                "No usable subnets matching availability zone {} found.\n"
                "Choose a different availability zone or try "
                "manually creating an instance in your specified region "
                "to populate the list of subnets and trying this again.",
                config["provider"]["availability_zone"])  # todo: err msg
            raise Exception(
                "No usable subnets matching availability zone {} "
                "found. Choose a different availability zone or try "
                "manually creating an instance in your specified region "
                "to populate the list of subnets and trying this again.".
                format(config["provider"]["availability_zone"]))

    subnet_ids = [s.subnet_id for s in subnets]
    if "SubnetIds" not in config["head_node"]:
        _set_config_info(head_subnet_src="default")
        config["head_node"]["SubnetIds"] = subnet_ids
    else:
        _set_config_info(head_subnet_src="config")

    if "SubnetIds" not in config["worker_nodes"]:
        _set_config_info(workers_subnet_src="default")
        config["worker_nodes"]["SubnetIds"] = subnet_ids
    else:
        _set_config_info(workers_subnet_src="config")

    return config
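
A hypothetical config fragment (the legacy head_node / worker_nodes layout used above) showing the keys _configure_subnet reads and fills in; all values are illustrative, and the "region" key and security-group ID are assumptions.

config = {
    "provider": {
        "type": "aws",
        "region": "us-west-2",                         # assumed; not read by this helper
        "availability_zone": "us-west-2a,us-west-2b",  # optional comma-separated AZ filter
        "use_internal_ips": False,
    },
    "head_node": {"SecurityGroupIds": ["sg-0123456789abcdef0"]},  # hypothetical SG
    "worker_nodes": {},  # no "SubnetIds" here, so the defaults get filled in
}
# After _configure_subnet(config), "SubnetIds" would be populated on both
# "head_node" and "worker_nodes" from the usable subnets, unless already set.
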