Example #1
def _get_vpc_id_or_die(ec2, subnet_id: str):
    subnets = _get_subnets_or_die(ec2, (subnet_id,))
    cli_logger.doassert(
        len(subnets) == 1,
        f"Expected 1 subnet with ID `{subnet_id}` but found {len(subnets)}",
    )
    return subnets[0].vpc_id
Example #2
def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str:
    """Returns the VPC id of the security groups with the provided security
    group ids.

    Errors if the provided security groups belong to multiple VPCs.
    Errors if no security group with any of the provided ids is identified.
    """
    sg_ids = list(set(sg_ids))

    ec2 = _resource("ec2", config)
    filters = [{"Name": "group-id", "Values": sg_ids}]
    security_groups = ec2.security_groups.filter(Filters=filters)
    vpc_ids = [sg.vpc_id for sg in security_groups]
    vpc_ids = list(set(vpc_ids))

    multiple_vpc_msg = "All security groups specified in the cluster config "\
        "should belong to the same VPC."
    cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg)
    assert len(vpc_ids) <= 1, multiple_vpc_msg

    no_sg_msg = "Failed to detect a security group with id equal to any of "\
        "the configured SecurityGroupIds."
    cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg)
    assert len(vpc_ids) > 0, no_sg_msg

    return vpc_ids[0]
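A hypothetical usage sketch of the contract described in the docstring: duplicate security group IDs are deduplicated, and the single shared VPC ID is returned. The IDs and the minimal config below are made up for illustration.

# Hypothetical call: both IDs refer to the same security group, so after
# deduplication exactly one VPC ID is found and returned.
config = {"provider": {"type": "aws", "region": "us-west-2"}}  # minimal, made-up config
vpc_id = _get_vpc_id_of_sg(
    ["sg-0123456789abcdef0", "sg-0123456789abcdef0"], config)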
Example #3
def _configure_node_type_from_launch_template(config, node_type):
    node_cfg = config[node_type]
    if "LaunchTemplate" not in node_cfg:
        return config

    ec2 = _client("ec2", config)
    kwargs = copy.deepcopy(node_cfg["LaunchTemplate"])
    template_version = kwargs.pop("Version", "$Default")
    kwargs["Versions"] = [template_version] if template_version else []

    template = ec2.describe_launch_template_versions(**kwargs)
    lt_versions = template["LaunchTemplateVersions"]
    cli_logger.doassert(
        len(lt_versions) == 1,
        "Expected to find 1 launch template but found {}".format(
            len(lt_versions)))
    assert len(lt_versions) == 1, \
        "Expected to find 1 launch template but found {}" \
        .format(len(lt_versions))
    lt_data = template["LaunchTemplateVersions"][0]["LaunchTemplateData"]

    # override launch template parameters with explicit node config parameters
    lt_data.update(node_cfg)
    # copy all new launch template parameters back to node config
    node_cfg.update(lt_data)

    return config
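The two `update` calls above implement a simple precedence rule: explicit node config parameters override the launch template's values, and any template-only parameters are then copied back into the node config. A minimal sketch with made-up values:

# Merge precedence sketch (values are illustrative only).
lt_data = {"InstanceType": "m5.large", "KeyName": "from-template"}
node_cfg = {"InstanceType": "p3.2xlarge"}

lt_data.update(node_cfg)   # explicit node config wins over template defaults
node_cfg.update(lt_data)   # merged result flows back into the node config
# node_cfg == {"InstanceType": "p3.2xlarge", "KeyName": "from-template"}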
Example #4
def _configure_iam_role(config):
    if "IamInstanceProfile" in config["head_node"]:
        _set_config_info(head_instance_profile_src="config")
        return config
    _set_config_info(head_instance_profile_src="default")

    profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)

    if profile is None:
        cli_logger.verbose(
            "Creating new IAM instance profile {} for use as the default.",
            cf.bold(DEFAULT_RAY_INSTANCE_PROFILE))
        client = _client("iam", config)
        client.create_instance_profile(
            InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
        profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
        time.sleep(15)  # wait for propagation

    cli_logger.doassert(profile is not None,
                        "Failed to create instance profile.")  # todo: err msg
    assert profile is not None, "Failed to create instance profile"

    if not profile.roles:
        role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
        if role is None:
            cli_logger.verbose(
                "Creating new IAM role {} for "
                "use as the default instance role.",
                cf.bold(DEFAULT_RAY_IAM_ROLE))
            iam = _resource("iam", config)
            iam.create_role(
                RoleName=DEFAULT_RAY_IAM_ROLE,
                AssumeRolePolicyDocument=json.dumps({
                    "Statement": [
                        {
                            "Effect": "Allow",
                            "Principal": {
                                "Service": "ec2.amazonaws.com"
                            },
                            "Action": "sts:AssumeRole",
                        },
                    ],
                }))
            role = _get_role(DEFAULT_RAY_IAM_ROLE, config)

            cli_logger.doassert(role is not None,
                                "Failed to create role.")  # todo: err msg
            assert role is not None, "Failed to create role"
        role.attach_policy(
            PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess")
        role.attach_policy(
            PolicyArn="arn:aws:iam::aws:policy/AmazonS3FullAccess")
        profile.add_role(RoleName=role.name)
        time.sleep(15)  # wait for propagation

    config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}

    return config
Example #5
def _get_subnets_or_die(ec2, subnet_ids: Tuple[str]):
    subnets = list(
        ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}])
    )

    # TODO: better error message
    cli_logger.doassert(
        len(subnets) == len(subnet_ids), "Not all subnet IDs found: {}", subnet_ids
    )
    assert len(subnets) == len(subnet_ids), "Subnet ID not found: {}".format(subnet_ids)
    return subnets
Example #6
def _get_vpc_id_or_die(ec2, subnet_id):
    subnet = list(
        ec2.subnets.filter(Filters=[{
            "Name": "subnet-id",
            "Values": [subnet_id]
        }]))

    # TODO: better error message
    cli_logger.doassert(len(subnet) == 1, "Subnet ID not found: {}", subnet_id)
    assert len(subnet) == 1, "Subnet ID not found: {}".format(subnet_id)
    subnet = subnet[0]
    return subnet.vpc_id
Example #7
def _create_security_group(config, vpc_id, group_name):
    client = _client("ec2", config)
    client.create_security_group(
        Description="Auto-created security group for Ray workers",
        GroupName=group_name,
        VpcId=vpc_id)
    security_group = _get_security_group(config, vpc_id, group_name)
    cli_logger.doassert(security_group,
                        "Failed to create security group")  # todo: err msg
    assert security_group, "Failed to create security group"

    cli_logger.verbose("Created new security group {}",
                       cf.bold(security_group.group_name),
                       _tags=dict(id=security_group.id))
    return security_group
Example #8
    def _set_ssh_ip_if_required(self):
        if self.ssh_ip is not None:
            return

        # We assume that this never changes.
        #   I think that's reasonable.
        deadline = time.time() + AUTOSCALER_NODE_START_WAIT_S
        with LogTimer(self.log_prefix + "Got IP"):
            ip = self._wait_for_ip(deadline)

            cli_logger.doassert(ip is not None, "Could not get node IP.")  # todo: msg
            assert ip is not None, "Unable to find IP of node"

        self.ssh_ip = ip

        # This should run before any SSH commands and therefore ensure that
        #   the ControlPath directory exists, allowing SSH to maintain
        #   persistent sessions later on.
        try:
            os.makedirs(self.ssh_control_path, mode=0o700, exist_ok=True)
        except OSError as e:
            cli_logger.warning("{}", str(e))  # todo: msg
Example #9
def up(cluster_config_file, min_workers, max_workers, no_restart, restart_only,
       yes, cluster_name, no_config_cache, redirect_command_output,
       use_login_shells, log_style, log_color, verbose):
    """Create or update a Ray cluster."""
    cli_logger.configure(log_style, log_color, verbose)

    if restart_only or no_restart:
        cli_logger.doassert(restart_only != no_restart,
                            "`{}` is incompatible with `{}`.",
                            cf.bold("--restart-only"), cf.bold("--no-restart"))
        assert restart_only != no_restart, "Cannot set both 'restart_only' " \
            "and 'no_restart' at the same time!"

    if urllib.parse.urlparse(cluster_config_file).scheme in ("http", "https"):
        try:
            response = urllib.request.urlopen(cluster_config_file, timeout=5)
            content = response.read()
            file_name = cluster_config_file.split("/")[-1]
            with open(file_name, "wb") as f:
                f.write(content)
            cluster_config_file = file_name
        except urllib.error.HTTPError as e:
            cli_logger.warning("{}", str(e))
            cli_logger.warning(
                "Could not download remote cluster configuration file.")
            cli_logger.old_info(logger, "Error downloading file: ", e)
    create_or_update_cluster(
        config_file=cluster_config_file,
        override_min_workers=min_workers,
        override_max_workers=max_workers,
        no_restart=no_restart,
        restart_only=restart_only,
        yes=yes,
        override_cluster_name=cluster_name,
        no_config_cache=no_config_cache,
        redirect_command_output=redirect_command_output,
        use_login_shells=use_login_shells)
Example #10
cli_logger.print("List: {}", cli_logger.render_list([1, 2, 3]))
cli_logger.newline()
cli_logger.very_verbose("Very verbose")
cli_logger.verbose("Verbose")
cli_logger.verbose_warning("Verbose warning")
cli_logger.verbose_error("Verbose error")
cli_logger.print("Info")
cli_logger.success("Success")
cli_logger.warning("Warning")
cli_logger.error("Error")
cli_logger.newline()
try:
    cli_logger.abort("Abort")
except Exception:
    pass
try:
    cli_logger.doassert(False, "Assert")
except Exception:
    pass
cli_logger.newline()
cli_logger.confirm(True, "example")
cli_logger.newline()
with cli_logger.indented():
    cli_logger.print("Indented")
with cli_logger.group("Group"):
    cli_logger.print("Group contents")
with cli_logger.timed("Timed (unimplemented)"):
    cli_logger.print("Timed contents")
with cli_logger.verbatim_error_ctx("Verbatim error"):
    cli_logger.print("Error contents")
Example #11
def _configure_key_pair(config):
    if "ssh_private_key" in config["auth"]:
        _set_config_info(keypair_src="config")

        # If the key is not configured via the cloudinit
        # UserData, it should be configured via KeyName or
        # else we will risk starting a node that we cannot
        # SSH into:

        if "UserData" not in config["head_node"]:
            cli_logger.doassert(  # todo: verify schema beforehand?
                "KeyName" in config["head_node"],
                "`KeyName` missing for head node.")  # todo: err msg
            assert "KeyName" in config["head_node"]

        if "UserData" not in config["worker_nodes"]:
            cli_logger.doassert(
                "KeyName" in config["worker_nodes"],
                "`KeyName` missing for worker nodes.")  # todo: err msg
            assert "KeyName" in config["worker_nodes"]

        return config

    _set_config_info(keypair_src="default")

    ec2 = _resource("ec2", config)

    # Writing the new ssh key to the filesystem fails if the ~/.ssh
    # directory doesn't already exist.
    os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True)

    # Try a few times to get or create a good key pair.
    MAX_NUM_KEYS = 30
    for i in range(MAX_NUM_KEYS):

        key_name = config["provider"].get("key_pair", {}).get("key_name")

        key_name, key_path = key_pair(i, config["provider"]["region"],
                                      key_name)
        key = _get_key(key_name, config)

        # Found a good key.
        if key and os.path.exists(key_path):
            break

        # We can safely create a new key.
        if not key and not os.path.exists(key_path):
            cli_logger.verbose(
                "Creating new key pair {} for use as the default.",
                cf.bold(key_name))
            key = ec2.create_key_pair(KeyName=key_name)

            # We need to make sure to _create_ the file with the right
            # permissions. In order to do that we need to change the default
            # os.open behavior to include the mode we want.
            with open(key_path, "w", opener=partial(os.open, mode=0o600)) as f:
                f.write(key.key_material)
            break

    if not key:
        cli_logger.abort(
            "No matching local key file for any of the key pairs in this "
            "account with ids from 0..{}. "
            "Consider deleting some unused keys pairs from your account.",
            key_name)

    cli_logger.doassert(
        os.path.exists(key_path), "Private key file " + cf.bold("{}") +
        " not found for " + cf.bold("{}"), key_path, key_name)  # todo: err msg
    assert os.path.exists(key_path), \
        "Private key file {} not found for {}".format(key_path, key_name)

    config["auth"]["ssh_private_key"] = key_path
    config["head_node"]["KeyName"] = key_name
    config["worker_nodes"]["KeyName"] = key_name

    return config
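The `opener=partial(os.open, mode=0o600)` call above creates the key file with owner-only permissions from the start, rather than creating it with default permissions and tightening them afterwards. A standalone sketch of the same technique (the path is made up, and the resulting mode is still subject to the process umask):

import os
from functools import partial

# Create the file with mode 0o600 at creation time instead of open() + os.chmod().
with open("/tmp/example_key.pem", "w",
          opener=partial(os.open, mode=0o600)) as f:
    f.write("-- key material would go here --")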
Example #12
def log_to_cli(config):
    provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(provider_name is not None,
                        "Could not find a pretty name for the AWS provider.")

    with cli_logger.group("{} config", provider_name):

        def same_everywhere(key):
            return config["head_node"][key] == config["worker_nodes"][key]

        def print_info(resource_string,
                       key,
                       head_src_key,
                       workers_src_key,
                       allowed_tags=None,
                       list_value=False):
            if allowed_tags is None:
                allowed_tags = ["default"]

            head_tags = {}
            workers_tags = {}

            if _log_info[head_src_key] in allowed_tags:
                head_tags[_log_info[head_src_key]] = True
            if _log_info[workers_src_key] in allowed_tags:
                workers_tags[_log_info[workers_src_key]] = True

            head_value_str = config["head_node"][key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if same_everywhere(key):
                cli_logger.labeled_value(  # todo: handle plural vs singular?
                    resource_string + " (head & workers)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
            else:
                workers_value_str = config["worker_nodes"][key]
                if list_value:
                    workers_value_str = cli_logger.render_list(
                        workers_value_str)

                cli_logger.labeled_value(
                    resource_string + " (head)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
                cli_logger.labeled_value(
                    resource_string + " (workers)",
                    "{}",
                    workers_value_str,
                    _tags=workers_tags)

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        profile_arn = config["head_node"]["IamInstanceProfile"].get("Arn")
        profile_name = _arn_to_name(profile_arn) \
            if profile_arn \
            else config["head_node"]["IamInstanceProfile"]["Name"]
        cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags)

        if ("KeyName" in config["head_node"]
                and "KeyName" in config["worker_nodes"]):
            print_info("EC2 Key pair", "KeyName", "keypair_src", "keypair_src")

        print_info(
            "VPC Subnets",
            "SubnetIds",
            "head_subnet_src",
            "workers_subnet_src",
            list_value=True)
        print_info(
            "EC2 Security groups",
            "SecurityGroupIds",
            "head_security_group_src",
            "workers_security_group_src",
            list_value=True)
        print_info(
            "EC2 AMI",
            "ImageId",
            "head_ami_src",
            "workers_ami_src",
            allowed_tags=["dlami"])

    cli_logger.newline()
Example #13
def submit(cluster_config_file, screen, tmux, stop, start, cluster_name,
           no_config_cache, port_forward, script, args, script_args, log_style,
           log_color, verbose):
    """Uploads and runs a script on the specified cluster.

    The script is automatically synced to the following location:

        os.path.join("~", os.path.basename(script))

    Example:
        >>> ray submit [CLUSTER.YAML] experiment.py -- --smoke-test
    """
    cli_logger.configure(log_style, log_color, verbose)

    cli_logger.doassert(not (screen and tmux),
                        "`{}` and `{}` are incompatible.", cf.bold("--screen"),
                        cf.bold("--tmux"))
    cli_logger.doassert(
        not (script_args and args),
        "`{0}` and `{1}` are incompatible. Use only `{1}`.\n"
        "Example: `{2}`", cf.bold("--args"), cf.bold("-- <args ...>"),
        cf.bold("ray submit script.py -- --arg=123 --flag"))

    assert not (screen and tmux), "Can specify only one of `screen` or `tmux`."
    assert not (script_args and args), "Use -- --arg1 --arg2 for script args."

    if args:
        cli_logger.warning(
            "`{}` is deprecated and will be removed in the future.",
            cf.bold("--args"))
        cli_logger.warning("Use `{}` instead. Example: `{}`.",
                           cf.bold("-- <args ...>"),
                           cf.bold("ray submit script.py -- --arg=123 --flag"))
        cli_logger.newline()

    if start:
        create_or_update_cluster(
            config_file=cluster_config_file,
            override_min_workers=None,
            override_max_workers=None,
            no_restart=False,
            restart_only=False,
            yes=True,
            override_cluster_name=cluster_name,
            no_config_cache=no_config_cache,
            redirect_command_output=False,
            use_login_shells=True)
    target = os.path.basename(script)
    target = os.path.join("~", target)
    rsync(
        cluster_config_file,
        script,
        target,
        cluster_name,
        no_config_cache=no_config_cache,
        down=False)

    command_parts = ["python", target]
    if script_args:
        command_parts += list(script_args)
    elif args is not None:
        command_parts += [args]

    port_forward = [(port, port) for port in list(port_forward)]
    cmd = " ".join(command_parts)
    exec_cluster(
        cluster_config_file,
        cmd=cmd,
        run_env="docker",
        screen=screen,
        tmux=tmux,
        stop=stop,
        start=False,
        override_cluster_name=cluster_name,
        no_config_cache=no_config_cache,
        port_forward=port_forward)
Example #14
def _configure_iam_role(config):
    head_node_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_node_type]["node_config"]
    if "IamInstanceProfile" in head_node_config:
        _set_config_info(head_instance_profile_src="config")
        return config
    _set_config_info(head_instance_profile_src="default")

    instance_profile_name = cwh.resolve_instance_profile_name(
        config["provider"],
        DEFAULT_RAY_INSTANCE_PROFILE,
    )
    profile = _get_instance_profile(instance_profile_name, config)

    if profile is None:
        cli_logger.verbose(
            "Creating new IAM instance profile {} for use as the default.",
            cf.bold(instance_profile_name),
        )
        client = _client("iam", config)
        client.create_instance_profile(InstanceProfileName=instance_profile_name)
        profile = _get_instance_profile(instance_profile_name, config)
        time.sleep(15)  # wait for propagation

    cli_logger.doassert(
        profile is not None, "Failed to create instance profile."
    )  # todo: err msg
    assert profile is not None, "Failed to create instance profile"

    if not profile.roles:
        role_name = cwh.resolve_iam_role_name(config["provider"], DEFAULT_RAY_IAM_ROLE)
        role = _get_role(role_name, config)
        if role is None:
            cli_logger.verbose(
                "Creating new IAM role {} for use as the default instance role.",
                cf.bold(role_name),
            )
            iam = _resource("iam", config)
            policy_doc = {
                "Statement": [
                    {
                        "Effect": "Allow",
                        "Principal": {"Service": "ec2.amazonaws.com"},
                        "Action": "sts:AssumeRole",
                    },
                ]
            }
            attach_policy_arns = cwh.resolve_policy_arns(
                config["provider"],
                iam,
                [
                    "arn:aws:iam::aws:policy/AmazonEC2FullAccess",
                    "arn:aws:iam::aws:policy/AmazonS3FullAccess",
                ],
            )

            iam.create_role(
                RoleName=role_name, AssumeRolePolicyDocument=json.dumps(policy_doc)
            )
            role = _get_role(role_name, config)
            cli_logger.doassert(
                role is not None, "Failed to create role."
            )  # todo: err msg

            assert role is not None, "Failed to create role"

            for policy_arn in attach_policy_arns:
                role.attach_policy(PolicyArn=policy_arn)

        profile.add_role(RoleName=role.name)
        time.sleep(15)  # wait for propagation
    # Add IAM role to "head_node" field so that it is applied only to
    # the head node -- not to workers with the same node type as the head.
    config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}

    return config
Example #15
def log_to_cli(config: Dict[str, Any]) -> None:
    provider_name = _PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(
        provider_name is not None, "Could not find a pretty name for the AWS provider."
    )

    head_node_type = config["head_node_type"]
    head_node_config = config["available_node_types"][head_node_type]["node_config"]

    with cli_logger.group("{} config", provider_name):

        def print_info(
            resource_string: str,
            key: str,
            src_key: str,
            allowed_tags: Optional[List[str]] = None,
            list_value: bool = False,
        ) -> None:
            if allowed_tags is None:
                allowed_tags = ["default"]

            node_tags = {}

            # set of configurations corresponding to `key`
            unique_settings = set()

            for node_type_key, node_type in config["available_node_types"].items():
                node_tags[node_type_key] = {}
                tag = _log_info[src_key][node_type_key]
                if tag in allowed_tags:
                    node_tags[node_type_key][tag] = True
                setting = node_type["node_config"].get(key)

                if list_value:
                    unique_settings.add(tuple(setting))
                else:
                    unique_settings.add(setting)

            head_value_str = head_node_config[key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if len(unique_settings) == 1:
                # all node types are configured the same, condense
                # log output
                cli_logger.labeled_value(
                    resource_string + " (all available node types)",
                    "{}",
                    head_value_str,
                    _tags=node_tags[config["head_node_type"]],
                )
            else:
                # do head node type first
                cli_logger.labeled_value(
                    resource_string + f" ({head_node_type})",
                    "{}",
                    head_value_str,
                    _tags=node_tags[head_node_type],
                )

                # go through remaining types
                for node_type_key, node_type in config["available_node_types"].items():
                    if node_type_key == head_node_type:
                        continue
                    workers_value_str = node_type["node_config"][key]
                    if list_value:
                        workers_value_str = cli_logger.render_list(workers_value_str)
                    cli_logger.labeled_value(
                        resource_string + f" ({node_type_key})",
                        "{}",
                        workers_value_str,
                        _tags=node_tags[node_type_key],
                    )

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        # head_node_config is the head_node_type's config,
        # config["head_node"] is a field that gets applied only to the actual
        # head node (and not workers of the head's node_type)
        assert (
            "IamInstanceProfile" in head_node_config
            or "IamInstanceProfile" in config["head_node"]
        )
        if "IamInstanceProfile" in head_node_config:
            # If the user manually configured the role we're here.
            IamProfile = head_node_config["IamInstanceProfile"]
        elif "IamInstanceProfile" in config["head_node"]:
            # If we filled the default IAM role, we're here.
            IamProfile = config["head_node"]["IamInstanceProfile"]
        profile_arn = IamProfile.get("Arn")
        profile_name = _arn_to_name(profile_arn) if profile_arn else IamProfile["Name"]
        cli_logger.labeled_value("IAM Profile", "{}", profile_name, _tags=tags)

        if all(
            "KeyName" in node_type["node_config"]
            for node_type in config["available_node_types"].values()
        ):
            print_info("EC2 Key pair", "KeyName", "keypair_src")

        print_info("VPC Subnets", "SubnetIds", "subnet_src", list_value=True)
        print_info(
            "EC2 Security groups",
            "SecurityGroupIds",
            "security_group_src",
            list_value=True,
        )
        print_info("EC2 AMI", "ImageId", "ami_src", allowed_tags=["dlami"])

    cli_logger.newline()
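The `IamInstanceProfile` branch above distinguishes two config shapes: a user-supplied profile lives in the head node type's `node_config`, while the auto-filled default lives in the top-level `head_node` field so it applies only to the actual head node. A made-up illustration of both shapes (ARNs and node type names are hypothetical):

# User-configured: profile attached to the head node *type*.
user_configured = {
    "head_node_type": "head",
    "available_node_types": {
        "head": {
            "node_config": {
                "IamInstanceProfile": {
                    "Arn": "arn:aws:iam::123456789012:instance-profile/my-profile"
                }
            }
        }
    },
    "head_node": {},
}

# Auto-filled default: profile placed in the head-node-only field.
auto_filled = {
    "head_node_type": "head",
    "available_node_types": {"head": {"node_config": {}}},
    "head_node": {
        "IamInstanceProfile": {
            "Arn": "arn:aws:iam::123456789012:instance-profile/default-profile"
        }
    },
}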