def _get_vpc_id_or_die(ec2, subnet_id: str):
    subnets = _get_subnets_or_die(ec2, (subnet_id,))
    cli_logger.doassert(
        len(subnets) == 1,
        f"Expected 1 subnet with ID `{subnet_id}` but found {len(subnets)}",
    )
    return subnets[0].vpc_id
def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str:
    """Returns the VPC id of the security groups with the provided
    security group ids.

    Errors if the provided security groups belong to multiple VPCs.
    Errors if no security group with any of the provided ids is identified.
    """
    sg_ids = list(set(sg_ids))

    ec2 = _resource("ec2", config)
    filters = [{"Name": "group-id", "Values": sg_ids}]
    security_groups = ec2.security_groups.filter(Filters=filters)
    vpc_ids = [sg.vpc_id for sg in security_groups]
    vpc_ids = list(set(vpc_ids))

    multiple_vpc_msg = "All security groups specified in the cluster config "\
        "should belong to the same VPC."
    cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg)
    assert len(vpc_ids) <= 1, multiple_vpc_msg

    no_sg_msg = "Failed to detect a security group with id equal to any of "\
        "the configured SecurityGroupIds."
    cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg)
    assert len(vpc_ids) > 0, no_sg_msg

    return vpc_ids[0]
def _configure_node_type_from_launch_template(config, node_type):
    node_cfg = config[node_type]
    if "LaunchTemplate" not in node_cfg:
        return config

    ec2 = _client("ec2", config)
    kwargs = copy.deepcopy(node_cfg["LaunchTemplate"])
    template_version = kwargs.pop("Version", "$Default")
    kwargs["Versions"] = [template_version] if template_version else []

    template = ec2.describe_launch_template_versions(**kwargs)
    lt_versions = template["LaunchTemplateVersions"]
    cli_logger.doassert(
        len(lt_versions) == 1,
        "Expected to find 1 launch template but found {}".format(
            len(lt_versions)))
    assert len(lt_versions) == 1, \
        "Expected to find 1 launch template but found {}" \
        .format(len(lt_versions))

    lt_data = template["LaunchTemplateVersions"][0]["LaunchTemplateData"]
    # override launch template parameters with explicit node config parameters
    lt_data.update(node_cfg)
    # copy all new launch template parameters back to node config
    node_cfg.update(lt_data)

    return config
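# A minimal sketch with hypothetical values (not taken from any real launch
# template) illustrating the merge precedence implemented above: explicit node
# config keys override the launch template data, and the merged result is then
# written back into the node config.
lt_data = {"InstanceType": "m5.large", "ImageId": "ami-0example"}  # template data
node_cfg = {"InstanceType": "p3.2xlarge"}  # explicit cluster config
lt_data.update(node_cfg)  # explicit node config wins on conflicting keys
node_cfg.update(lt_data)  # node config picks up the template-only keys
assert node_cfg == {"InstanceType": "p3.2xlarge", "ImageId": "ami-0example"}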
def _configure_iam_role(config):
    if "IamInstanceProfile" in config["head_node"]:
        _set_config_info(head_instance_profile_src="config")
        return config
    _set_config_info(head_instance_profile_src="default")

    profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)

    if profile is None:
        cli_logger.verbose(
            "Creating new IAM instance profile {} for use as the default.",
            cf.bold(DEFAULT_RAY_INSTANCE_PROFILE))
        client = _client("iam", config)
        client.create_instance_profile(
            InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
        profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
        time.sleep(15)  # wait for propagation

    cli_logger.doassert(profile is not None,
                        "Failed to create instance profile.")  # todo: err msg
    assert profile is not None, "Failed to create instance profile"

    if not profile.roles:
        role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
        if role is None:
            cli_logger.verbose(
                "Creating new IAM role {} for "
                "use as the default instance role.",
                cf.bold(DEFAULT_RAY_IAM_ROLE))
            iam = _resource("iam", config)
            iam.create_role(
                RoleName=DEFAULT_RAY_IAM_ROLE,
                AssumeRolePolicyDocument=json.dumps({
                    "Statement": [
                        {
                            "Effect": "Allow",
                            "Principal": {
                                "Service": "ec2.amazonaws.com"
                            },
                            "Action": "sts:AssumeRole",
                        },
                    ],
                }))
            role = _get_role(DEFAULT_RAY_IAM_ROLE, config)

            cli_logger.doassert(role is not None,
                                "Failed to create role.")  # todo: err msg
            assert role is not None, "Failed to create role"
        role.attach_policy(
            PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess")
        role.attach_policy(
            PolicyArn="arn:aws:iam::aws:policy/AmazonS3FullAccess")
        profile.add_role(RoleName=role.name)
        time.sleep(15)  # wait for propagation

    config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}

    return config
def _get_subnets_or_die(ec2, subnet_ids: Tuple[str]):
    subnets = list(
        ec2.subnets.filter(
            Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}])
    )

    # TODO: better error message
    cli_logger.doassert(
        len(subnets) == len(subnet_ids),
        "Not all subnet IDs found: {}", subnet_ids
    )
    assert len(subnets) == len(subnet_ids), \
        "Subnet ID not found: {}".format(subnet_ids)
    return subnets
def _get_vpc_id_or_die(ec2, subnet_id):
    subnet = list(
        ec2.subnets.filter(Filters=[{
            "Name": "subnet-id",
            "Values": [subnet_id]
        }]))

    # TODO: better error message
    cli_logger.doassert(len(subnet) == 1, "Subnet ID not found: {}", subnet_id)
    assert len(subnet) == 1, "Subnet ID not found: {}".format(subnet_id)

    subnet = subnet[0]
    return subnet.vpc_id
def _create_security_group(config, vpc_id, group_name):
    client = _client("ec2", config)
    client.create_security_group(
        Description="Auto-created security group for Ray workers",
        GroupName=group_name,
        VpcId=vpc_id)
    security_group = _get_security_group(config, vpc_id, group_name)

    cli_logger.doassert(security_group,
                        "Failed to create security group")  # err msg

    cli_logger.verbose(
        "Created new security group {}",
        cf.bold(security_group.group_name),
        _tags=dict(id=security_group.id))

    assert security_group, "Failed to create security group"
    return security_group
def _set_ssh_ip_if_required(self):
    if self.ssh_ip is not None:
        return

    # We assume that this never changes.
    #   I think that's reasonable.
    deadline = time.time() + AUTOSCALER_NODE_START_WAIT_S
    with LogTimer(self.log_prefix + "Got IP"):
        ip = self._wait_for_ip(deadline)

        cli_logger.doassert(ip is not None,
                            "Could not get node IP.")  # todo: msg
        assert ip is not None, "Unable to find IP of node"

    self.ssh_ip = ip

    # This should run before any SSH commands and therefore ensure that
    #   the ControlPath directory exists, allowing SSH to maintain
    #   persistent sessions later on.
    try:
        os.makedirs(self.ssh_control_path, mode=0o700, exist_ok=True)
    except OSError as e:
        cli_logger.warning("{}", str(e))  # todo: msg
def up(cluster_config_file, min_workers, max_workers, no_restart,
       restart_only, yes, cluster_name, no_config_cache,
       redirect_command_output, use_login_shells, log_style, log_color,
       verbose):
    """Create or update a Ray cluster."""
    cli_logger.configure(log_style, log_color, verbose)

    if restart_only or no_restart:
        cli_logger.doassert(restart_only != no_restart,
                            "`{}` is incompatible with `{}`.",
                            cf.bold("--restart-only"),
                            cf.bold("--no-restart"))
        assert restart_only != no_restart, "Cannot set both 'restart_only' " \
            "and 'no_restart' at the same time!"

    if urllib.parse.urlparse(cluster_config_file).scheme in ("http", "https"):
        try:
            response = urllib.request.urlopen(cluster_config_file, timeout=5)
            content = response.read()
            file_name = cluster_config_file.split("/")[-1]
            with open(file_name, "wb") as f:
                f.write(content)
            cluster_config_file = file_name
        except urllib.error.HTTPError as e:
            cli_logger.warning("{}", str(e))
            cli_logger.warning(
                "Could not download remote cluster configuration file.")
            cli_logger.old_info(logger, "Error downloading file: ", e)

    create_or_update_cluster(
        config_file=cluster_config_file,
        override_min_workers=min_workers,
        override_max_workers=max_workers,
        no_restart=no_restart,
        restart_only=restart_only,
        yes=yes,
        override_cluster_name=cluster_name,
        no_config_cache=no_config_cache,
        redirect_command_output=redirect_command_output,
        use_login_shells=use_login_shells)
cli_logger.print("List: {}", cli_logger.render_list([1, 2, 3]))
cli_logger.newline()

cli_logger.very_verbose("Very verbose")
cli_logger.verbose("Verbose")
cli_logger.verbose_warning("Verbose warning")
cli_logger.verbose_error("Verbose error")
cli_logger.print("Info")
cli_logger.success("Success")
cli_logger.warning("Warning")
cli_logger.error("Error")
cli_logger.newline()

try:
    cli_logger.abort("Abort")
except Exception:
    pass

try:
    cli_logger.doassert(False, "Assert")
except Exception:
    pass

cli_logger.newline()
cli_logger.confirm(True, "example")
cli_logger.newline()

with cli_logger.indented():
    cli_logger.print("Indented")

with cli_logger.group("Group"):
    cli_logger.print("Group contents")

with cli_logger.timed("Timed (unimplemented)"):
    cli_logger.print("Timed contents")

with cli_logger.verbatim_error_ctx("Verbatim error"):
    cli_logger.print("Error contents")
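# A minimal sketch of the doassert/assert pairing used in the functions in
# this file: cli_logger.doassert renders the formatted message when pretty
# logging is enabled, and a plain assert remains as the fallback. The helper
# `_require_nonempty` is hypothetical and not part of Ray.
def _require_nonempty(items, what):
    cli_logger.doassert(len(items) > 0, "No {} found.", what)
    assert len(items) > 0, "No {} found".format(what)
    return items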
def _configure_key_pair(config):
    if "ssh_private_key" in config["auth"]:
        _set_config_info(keypair_src="config")

        # If the key is not configured via the cloudinit
        # UserData, it should be configured via KeyName or
        # else we will risk starting a node that we cannot
        # SSH into:
        if "UserData" not in config["head_node"]:
            cli_logger.doassert(  # todo: verify schema beforehand?
                "KeyName" in config["head_node"],
                "`KeyName` missing for head node.")  # todo: err msg
            assert "KeyName" in config["head_node"]

        if "UserData" not in config["worker_nodes"]:
            cli_logger.doassert(
                "KeyName" in config["worker_nodes"],
                "`KeyName` missing for worker nodes.")  # todo: err msg
            assert "KeyName" in config["worker_nodes"]

        return config
    _set_config_info(keypair_src="default")

    ec2 = _resource("ec2", config)

    # Writing the new ssh key to the filesystem fails if the ~/.ssh
    # directory doesn't already exist.
    os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)

    # Try a few times to get or create a good key pair.
    MAX_NUM_KEYS = 30
    for i in range(MAX_NUM_KEYS):

        key_name = config["provider"].get("key_pair", {}).get("key_name")

        key_name, key_path = key_pair(i, config["provider"]["region"],
                                      key_name)
        key = _get_key(key_name, config)

        # Found a good key.
        if key and os.path.exists(key_path):
            break

        # We can safely create a new key.
        if not key and not os.path.exists(key_path):
            cli_logger.verbose(
                "Creating new key pair {} for use as the default.",
                cf.bold(key_name))
            key = ec2.create_key_pair(KeyName=key_name)

            # We need to make sure to _create_ the file with the right
            # permissions. In order to do that we need to change the default
            # os.open behavior to include the mode we want.
            with open(key_path, "w",
                      opener=partial(os.open, mode=0o600)) as f:
                f.write(key.key_material)
            break

    if not key:
        cli_logger.abort(
            "No matching local key file for any of the key pairs in this "
            "account with ids from 0..{}. "
            "Consider deleting some unused key pairs from your account.",
            key_name)

    cli_logger.doassert(
        os.path.exists(key_path), "Private key file " + cf.bold("{}") +
        " not found for " + cf.bold("{}"), key_path, key_name)  # todo: err msg
    assert os.path.exists(key_path), \
        "Private key file {} not found for {}".format(key_path, key_name)

    config["auth"]["ssh_private_key"] = key_path
    config["head_node"]["KeyName"] = key_name
    config["worker_nodes"]["KeyName"] = key_name

    return config
def log_to_cli(config):
    provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(provider_name is not None,
                        "Could not find a pretty name for the AWS provider.")

    with cli_logger.group("{} config", provider_name):

        def same_everywhere(key):
            return config["head_node"][key] == config["worker_nodes"][key]

        def print_info(resource_string,
                       key,
                       head_src_key,
                       workers_src_key,
                       allowed_tags=None,
                       list_value=False):
            if allowed_tags is None:
                allowed_tags = ["default"]

            head_tags = {}
            workers_tags = {}

            if _log_info[head_src_key] in allowed_tags:
                head_tags[_log_info[head_src_key]] = True
            if _log_info[workers_src_key] in allowed_tags:
                workers_tags[_log_info[workers_src_key]] = True

            head_value_str = config["head_node"][key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if same_everywhere(key):
                cli_logger.labeled_value(  # todo: handle plural vs singular?
                    resource_string + " (head & workers)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
            else:
                workers_value_str = config["worker_nodes"][key]
                if list_value:
                    workers_value_str = cli_logger.render_list(
                        workers_value_str)

                cli_logger.labeled_value(
                    resource_string + " (head)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
                cli_logger.labeled_value(
                    resource_string + " (workers)",
                    "{}",
                    workers_value_str,
                    _tags=workers_tags)

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        profile_arn = config["head_node"]["IamInstanceProfile"].get("Arn")
        profile_name = _arn_to_name(profile_arn) \
            if profile_arn \
            else config["head_node"]["IamInstanceProfile"]["Name"]
        cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags)

        if ("KeyName" in config["head_node"]
                and "KeyName" in config["worker_nodes"]):
            print_info("EC2 Key pair", "KeyName", "keypair_src", "keypair_src")

        print_info(
            "VPC Subnets",
            "SubnetIds",
            "head_subnet_src",
            "workers_subnet_src",
            list_value=True)
        print_info(
            "EC2 Security groups",
            "SecurityGroupIds",
            "head_security_group_src",
            "workers_security_group_src",
            list_value=True)
        print_info(
            "EC2 AMI",
            "ImageId",
            "head_ami_src",
            "workers_ami_src",
            allowed_tags=["dlami"])

    cli_logger.newline()
def submit(cluster_config_file, screen, tmux, stop, start, cluster_name,
           no_config_cache, port_forward, script, args, script_args,
           log_style, log_color, verbose):
    """Uploads and runs a script on the specified cluster.

    The script is automatically synced to the following location:

        os.path.join("~", os.path.basename(script))

    Example:
        >>> ray submit [CLUSTER.YAML] experiment.py -- --smoke-test
    """
    cli_logger.configure(log_style, log_color, verbose)

    cli_logger.doassert(not (screen and tmux),
                        "`{}` and `{}` are incompatible.",
                        cf.bold("--screen"), cf.bold("--tmux"))
    cli_logger.doassert(
        not (script_args and args),
        "`{0}` and `{1}` are incompatible. Use only `{1}`.\n"
        "Example: `{2}`", cf.bold("--args"), cf.bold("-- <args ...>"),
        cf.bold("ray submit script.py -- --arg=123 --flag"))

    assert not (screen and tmux), "Can specify only one of `screen` or `tmux`."
    assert not (script_args and args), "Use -- --arg1 --arg2 for script args."

    if args:
        cli_logger.warning(
            "`{}` is deprecated and will be removed in the future.",
            cf.bold("--args"))
        cli_logger.warning(
            "Use `{}` instead. Example: `{}`.", cf.bold("-- <args ...>"),
            cf.bold("ray submit script.py -- --arg=123 --flag"))
        cli_logger.newline()

    if start:
        create_or_update_cluster(
            config_file=cluster_config_file,
            override_min_workers=None,
            override_max_workers=None,
            no_restart=False,
            restart_only=False,
            yes=True,
            override_cluster_name=cluster_name,
            no_config_cache=no_config_cache,
            redirect_command_output=False,
            use_login_shells=True)
    target = os.path.basename(script)
    target = os.path.join("~", target)
    rsync(
        cluster_config_file,
        script,
        target,
        cluster_name,
        no_config_cache=no_config_cache,
        down=False)

    command_parts = ["python", target]
    if script_args:
        command_parts += list(script_args)
    elif args is not None:
        command_parts += [args]

    port_forward = [(port, port) for port in list(port_forward)]
    cmd = " ".join(command_parts)
    exec_cluster(
        cluster_config_file,
        cmd=cmd,
        run_env="docker",
        screen=screen,
        tmux=tmux,
        stop=stop,
        start=False,
        override_cluster_name=cluster_name,
        no_config_cache=no_config_cache,
        port_forward=port_forward)
def _configure_iam_role(config):
    head_node_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_node_type][
        "node_config"]
    if "IamInstanceProfile" in head_node_config:
        _set_config_info(head_instance_profile_src="config")
        return config
    _set_config_info(head_instance_profile_src="default")

    instance_profile_name = cwh.resolve_instance_profile_name(
        config["provider"],
        DEFAULT_RAY_INSTANCE_PROFILE,
    )
    profile = _get_instance_profile(instance_profile_name, config)

    if profile is None:
        cli_logger.verbose(
            "Creating new IAM instance profile {} for use as the default.",
            cf.bold(instance_profile_name),
        )
        client = _client("iam", config)
        client.create_instance_profile(
            InstanceProfileName=instance_profile_name)
        profile = _get_instance_profile(instance_profile_name, config)
        time.sleep(15)  # wait for propagation

    cli_logger.doassert(
        profile is not None, "Failed to create instance profile."
    )  # todo: err msg
    assert profile is not None, "Failed to create instance profile"

    if not profile.roles:
        role_name = cwh.resolve_iam_role_name(config["provider"],
                                              DEFAULT_RAY_IAM_ROLE)
        role = _get_role(role_name, config)
        if role is None:
            cli_logger.verbose(
                "Creating new IAM role {} for use as the default instance "
                "role.",
                cf.bold(role_name),
            )
            iam = _resource("iam", config)
            policy_doc = {
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"Service": "ec2.amazonaws.com"},
                        "Action": "sts:AssumeRole",
                    },
                ]
            }
            attach_policy_arns = cwh.resolve_policy_arns(
                config["provider"],
                iam,
                [
                    "arn:aws:iam::aws:policy/AmazonEC2FullAccess",
                    "arn:aws:iam::aws:policy/AmazonS3FullAccess",
                ],
            )

            iam.create_role(
                RoleName=role_name,
                AssumeRolePolicyDocument=json.dumps(policy_doc))
            role = _get_role(role_name, config)
            cli_logger.doassert(
                role is not None, "Failed to create role."
            )  # todo: err msg
            assert role is not None, "Failed to create role"

            for policy_arn in attach_policy_arns:
                role.attach_policy(PolicyArn=policy_arn)

        profile.add_role(RoleName=role.name)
        time.sleep(15)  # wait for propagation

    # Add IAM role to "head_node" field so that it is applied only to
    # the head node -- not to workers with the same node type as the head.
    config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}

    return config
def log_to_cli(config: Dict[str, Any]) -> None:
    provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(
        provider_name is not None,
        "Could not find a pretty name for the AWS provider."
    )

    head_node_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_node_type][
        "node_config"]

    with cli_logger.group("{} config", provider_name):

        def print_info(
            resource_string: str,
            key: str,
            src_key: str,
            allowed_tags: Optional[List[str]] = None,
            list_value: bool = False,
        ) -> None:
            if allowed_tags is None:
                allowed_tags = ["default"]

            node_tags = {}

            # set of configurations corresponding to `key`
            unique_settings = set()

            for node_type_key, node_type in config[
                    "available_node_types"].items():
                node_tags[node_type_key] = {}
                tag = _log_info[src_key][node_type_key]
                if tag in allowed_tags:
                    node_tags[node_type_key][tag] = True
                setting = node_type["node_config"].get(key)

                if list_value:
                    unique_settings.add(tuple(setting))
                else:
                    unique_settings.add(setting)

            head_value_str = head_node_config[key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if len(unique_settings) == 1:
                # all node types are configured the same, condense
                # log output
                cli_logger.labeled_value(
                    resource_string + " (all available node types)",
                    "{}",
                    head_value_str,
                    _tags=node_tags[config["head_node_type"]],
                )
            else:
                # do head node type first
                cli_logger.labeled_value(
                    resource_string + f" ({head_node_type})",
                    "{}",
                    head_value_str,
                    _tags=node_tags[head_node_type],
                )

                # go through remaining types
                for node_type_key, node_type in config[
                        "available_node_types"].items():
                    if node_type_key == head_node_type:
                        continue
                    workers_value_str = node_type["node_config"][key]
                    if list_value:
                        workers_value_str = cli_logger.render_list(
                            workers_value_str)
                    cli_logger.labeled_value(
                        resource_string + f" ({node_type_key})",
                        "{}",
                        workers_value_str,
                        _tags=node_tags[node_type_key],
                    )

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}

        # head_node_config is the head_node_type's config,
        # config["head_node"] is a field that gets applied only to the actual
        # head node (and not workers of the head's node_type)
        assert (
            "IamInstanceProfile" in head_node_config
            or "IamInstanceProfile" in config["head_node"]
        )
        if "IamInstanceProfile" in head_node_config:
            # If the user manually configured the role we're here.
            IamProfile = head_node_config["IamInstanceProfile"]
        elif "IamInstanceProfile" in config["head_node"]:
            # If we filled the default IAM role, we're here.
            IamProfile = config["head_node"]["IamInstanceProfile"]

        profile_arn = IamProfile.get("Arn")
        profile_name = _arn_to_name(profile_arn) \
            if profile_arn \
            else IamProfile["Name"]
        cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags)

        if all(
            "KeyName" in node_type["node_config"]
            for node_type in config["available_node_types"].values()
        ):
            print_info("EC2 Key pair", "KeyName", "keypair_src")

        print_info("VPC Subnets", "SubnetIds", "subnet_src", list_value=True)
        print_info(
            "EC2 Security groups",
            "SecurityGroupIds",
            "security_group_src",
            list_value=True,
        )
        print_info("EC2 AMI", "ImageId", "ami_src", allowed_tags=["dlami"])

    cli_logger.newline()