Example 1
    def terminate_node(self, node_id):
        node = self._get_cached_node(node_id)
        if self.cache_stopped_nodes:
            if node.spot_instance_request_id:
                cli_logger.print(
                    "Terminating instance {} " +
                    cf.gray("(cannot stop spot instances, only terminate)"),
                    node_id)  # todo: show node name?

                cli_logger.old_info(
                    logger,
                    "AWSNodeProvider: terminating node {} (spot nodes cannot "
                    "be stopped, only terminated)", node_id)
                node.terminate()
            else:
                cli_logger.print("Stopping instance {} " + cf.gray(
                    "(to terminate instead, "
                    "set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration)"),
                                 node_id)  # todo: show node name?

                cli_logger.old_info(
                    logger,
                    "AWSNodeProvider: stopping node {}. To terminate nodes "
                    "on stop, set 'cache_stopped_nodes: False' in the "
                    "provider config.".format(node_id))
                node.stop()
        else:
            node.terminate()

        self.tag_cache.pop(node_id, None)
        self.tag_cache_pending.pop(node_id, None)
Example 2
    def rsync_down(self, source, target):
        cli_logger.old_info(logger, "{}Syncing {} from {}...", self.log_prefix,
                            source, target)

        self.cmd_runner.run_rsync_down(source, target)
        cli_logger.verbose("`rsync`ed {} (remote) to {} (local)",
                           cf.bold(source), cf.bold(target))
Example 3
def _check_ami(config):
    """Provide helpful message for missing ImageId for node configuration."""

    _set_config_info(head_ami_src="config", workers_ami_src="config")

    region = config["provider"]["region"]
    default_ami = DEFAULT_AMI.get(region)
    if not default_ami:
        # If we do not provide a default AMI for the given region, noop.
        return

    if config["head_node"].get("ImageId", "").lower() == "latest_dlami":
        config["head_node"]["ImageId"] = default_ami
        _set_config_info(head_ami_src="dlami")
        cli_logger.old_info(
            logger, "_check_ami: head node ImageId is 'latest_dlami'. "
            "Using '{ami_id}', which is the default {ami_name} "
            "for your region ({region}).",
            ami_id=default_ami,
            ami_name=DEFAULT_AMI_NAME,
            region=region)

    if config["worker_nodes"].get("ImageId", "").lower() == "latest_dlami":
        config["worker_nodes"]["ImageId"] = default_ami
        _set_config_info(workers_ami_src="dlami")
        cli_logger.old_info(
            logger, "_check_ami: worker nodes ImageId is 'latest_dlami'. "
            "Using '{ami_id}', which is the default {ami_name} "
            "for your region ({region}).",
            ami_id=default_ami,
            ami_name=DEFAULT_AMI_NAME,
            region=region)
Example 4
        def remaining_nodes():
            workers = provider.non_terminated_nodes(
                {TAG_RAY_NODE_KIND: NODE_KIND_WORKER})

            if keep_min_workers:
                min_workers = config.get("min_workers", 0)

                cli_logger.print(
                    "{} random worker nodes will not be shut down. " +
                    cf.dimmed("(due to {})"), cf.bold(min_workers),
                    cf.bold("--keep-min-workers"))
                cli_logger.old_info(logger,
                                    "teardown_cluster: Keeping {} nodes...",
                                    min_workers)

                workers = random.sample(workers, len(workers) - min_workers)

            # todo: it's weird to kill the head node but not all workers
            if workers_only:
                cli_logger.print(
                    "The head node will not be shut down. " +
                    cf.dimmed("(due to {})"), cf.bold("--workers-only"))

                return workers

            head = provider.non_terminated_nodes(
                {TAG_RAY_NODE_KIND: NODE_KIND_HEAD})

            return head + workers
Example 5
def _configure_security_group(config):
    _set_config_info(head_security_group_src="config",
                     workers_security_group_src="config")

    node_types_to_configure = [
        node_type for node_type, config_key in NODE_TYPE_CONFIG_KEYS.items()
        if "SecurityGroupIds" not in config[NODE_TYPE_CONFIG_KEYS[node_type]]
    ]
    if not node_types_to_configure:
        return config  # have user-defined groups

    security_groups = _upsert_security_groups(config, node_types_to_configure)

    if NODE_TYPE_HEAD in node_types_to_configure:
        head_sg = security_groups[NODE_TYPE_HEAD]

        _set_config_info(head_security_group_src="default")
        cli_logger.old_info(
            logger, "_configure_security_group: "
            "SecurityGroupIds not specified for head node, using {} ({})",
            head_sg.group_name, head_sg.id)
        config["head_node"]["SecurityGroupIds"] = [head_sg.id]

    if NODE_TYPE_WORKER in node_types_to_configure:
        workers_sg = security_groups[NODE_TYPE_WORKER]

        _set_config_info(workers_security_group_src="default")
        cli_logger.old_info(
            logger, "_configure_security_group: "
            "SecurityGroupIds not specified for workers, using {} ({})",
            workers_sg.group_name, workers_sg.id)
        config["worker_nodes"]["SecurityGroupIds"] = [workers_sg.id]

    return config
Example 6
    def run(self,
            cmd,
            timeout=120,
            exit_on_fail=False,
            port_forward=None,
            with_output=False,
            ssh_options_override=None,
            **kwargs):
        ssh_options = ssh_options_override or self.ssh_options

        assert isinstance(
            ssh_options, SSHOptions
        ), "ssh_options must be of type SSHOptions, got {}".format(
            type(ssh_options))

        self._set_ssh_ip_if_required()

        if is_using_login_shells():
            ssh = ["ssh", "-tt"]
        else:
            ssh = ["ssh"]

        if port_forward:
            with cli_logger.group("Forwarding ports"):
                if not isinstance(port_forward, list):
                    port_forward = [port_forward]
                for local, remote in port_forward:
                    cli_logger.verbose(
                        "Forwarding port {} to port {} on localhost.",
                        cf.bold(local), cf.bold(remote))  # todo: msg
                    cli_logger.old_info(logger,
                                        "{}Forwarding {} -> localhost:{}",
                                        self.log_prefix, local, remote)
                    ssh += ["-L", "{}:localhost:{}".format(remote, local)]

        final_cmd = ssh + ssh_options.to_ssh_options_list(
            timeout=timeout) + ["{}@{}".format(self.ssh_user, self.ssh_ip)]
        if cmd:
            if is_using_login_shells():
                final_cmd += _with_interactive(cmd)
            else:
                final_cmd += [cmd]
            cli_logger.old_info(logger, "{}Running {}", self.log_prefix,
                                " ".join(final_cmd))
        else:
            # We do this because `-o ControlMaster` causes the `-N` flag to
            # still create an interactive shell in some ssh versions.
            final_cmd.append(quote("while true; do sleep 86400; done"))

        cli_logger.verbose("Running `{}`", cf.bold(cmd))
        with cli_logger.indented():
            cli_logger.very_verbose("Full command is `{}`",
                                    cf.bold(" ".join(final_cmd)))

        if cli_logger.verbosity > 0:
            with cli_logger.indented():
                return self._run_helper(final_cmd, with_output, exit_on_fail)
        else:
            return self._run_helper(final_cmd, with_output, exit_on_fail)
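As the loop above shows, `port_forward` is consumed as one or more `(local, remote)` pairs rather than bare port numbers. A hypothetical invocation, assuming `runner` is an instance of this command-runner class, might look like:

# Forward one port pair while running a command on the node; the
# (8265, 8265) tuple follows the (local, remote) unpacking used above.
runner.run(
    "uptime",
    port_forward=[(8265, 8265)],
    with_output=False)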
Example 7
    def rsync_down(self, source, target, file_mount=False):
        cli_logger.old_info(logger, "{}Syncing {} from {}...", self.log_prefix,
                            source, target)

        options = {"file_mount": file_mount}
        self.cmd_runner.run_rsync_down(source, target, options=options)
        cli_logger.verbose("`rsync`ed {} (remote) to {} (local)",
                           cf.bold(source), cf.bold(target))
Example 8
def _bootstrap_config(config: Dict[str, Any],
                      no_config_cache: bool = False) -> Dict[str, Any]:
    config = prepare_config(config)

    hasher = hashlib.sha1()
    hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
    cache_key = os.path.join(tempfile.gettempdir(),
                             "ray-config-{}".format(hasher.hexdigest()))

    if os.path.exists(cache_key) and not no_config_cache:
        cli_logger.old_info(logger, "Using cached config at {}", cache_key)

        config_cache = json.loads(open(cache_key).read())
        if config_cache.get("_version", -1) == CONFIG_CACHE_VERSION:
            # todo: is it fine to re-resolve? afaik it should be.
            # we can have migrations otherwise or something
            # but this seems overcomplicated given that resolving is
            # relatively cheap
            try_reload_log_state(config_cache["config"]["provider"],
                                 config_cache.get("provider_log_info"))
            cli_logger.verbose("Loaded cached config from " + cf.bold("{}"),
                               cache_key)

            return config_cache["config"]
        else:
            cli_logger.warning(
                "Found cached cluster config "
                "but the version " + cf.bold("{}") + " "
                "(expected " + cf.bold("{}") + ") does not match.\n"
                "This is normal if cluster launcher was updated.\n"
                "Config will be re-resolved.",
                config_cache.get("_version", "none"), CONFIG_CACHE_VERSION)
    validate_config(config)

    importer = NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        raise NotImplementedError("Unsupported provider {}".format(
            config["provider"]))

    provider_cls = importer(config["provider"])

    with cli_logger.timed(  # todo: better message
            "Bootstraping {} config",
            PROVIDER_PRETTY_NAMES.get(config["provider"]["type"])):
        resolved_config = provider_cls.bootstrap_config(config)

    if not no_config_cache:
        with open(cache_key, "w") as f:
            config_cache = {
                "_version": CONFIG_CACHE_VERSION,
                "provider_log_info": try_get_log_state(config["provider"]),
                "config": resolved_config
            }
            f.write(json.dumps(config_cache))
    return resolved_config
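For reference, the cache key used above is simply a SHA-1 digest of the prepared config written into the system temp directory. A standalone sketch (standard library only, hypothetical config values):

import hashlib
import json
import os
import tempfile

config = {"cluster_name": "demo", "provider": {"type": "aws"}}  # hypothetical
hasher = hashlib.sha1()
hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
cache_key = os.path.join(tempfile.gettempdir(),
                         "ray-config-{}".format(hasher.hexdigest()))
# e.g. /tmp/ray-config-<40-hex-char digest>; the same config always maps
# to the same file, which is what makes the cache lookup above work.
print(cache_key)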
Example 9
    def wait_ready(self, deadline):
        with cli_logger.group("Waiting for SSH to become available",
                              _numbered=("[]", 1, 6)):
            with LogTimer(self.log_prefix + "Got remote shell"):
                cli_logger.old_info(logger, "{}Waiting for remote shell...",
                                    self.log_prefix)

                cli_logger.print("Running `{}` as a test.", cf.bold("uptime"))
                first_conn_refused_time = None
                while time.time() < deadline and \
                        not self.provider.is_terminated(self.node_id):
                    try:
                        cli_logger.old_debug(logger,
                                             "{}Waiting for remote shell...",
                                             self.log_prefix)

                        # Run outside of the container
                        self.cmd_runner.run("uptime", run_env="host")
                        cli_logger.old_debug(logger, "Uptime succeeded.")
                        cli_logger.success("Success.")
                        return True
                    except ProcessRunnerError as e:
                        first_conn_refused_time = \
                            cmd_output_util.handle_ssh_fails(
                                e, first_conn_refused_time,
                                retry_interval=READY_CHECK_INTERVAL)
                        time.sleep(READY_CHECK_INTERVAL)
                    except Exception as e:
                        # TODO(maximsmol): we should not be ignoring
                        # exceptions if they get filtered properly
                        # (new style log + non-interactive shells)
                        #
                        # however threading this configuration state
                        # is a pain and I'm leaving it for later

                        retry_str = str(e)
                        if hasattr(e, "cmd"):
                            retry_str = "(Exit Status {}): {}".format(
                                e.returncode, " ".join(e.cmd))

                        cli_logger.print(
                            "SSH still not available {}, "
                            "retrying in {} seconds.", cf.dimmed(retry_str),
                            cf.bold(str(READY_CHECK_INTERVAL)))
                        cli_logger.old_debug(logger,
                                             "{}Node not up, retrying: {}",
                                             self.log_prefix, retry_str)

                        time.sleep(READY_CHECK_INTERVAL)

        assert False, "Unable to connect to node"
Example 10
File: updater.py Project: aeli0/ray
    def run(self):
        cli_logger.old_info(logger, "{}Updating to {}", self.log_prefix,
                            self.runtime_hash)

        try:
            with LogTimer(self.log_prefix +
                          "Applied config {}".format(self.runtime_hash)):
                self.do_update()
        except Exception as e:
            error_str = str(e)
            if hasattr(e, "cmd"):
                error_str = "(Exit Status {}) {}".format(
                    e.returncode, " ".join(e.cmd))

            self.provider.set_node_tags(
                self.node_id, {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
            cli_logger.error("New status: {}", cf.bold(STATUS_UPDATE_FAILED))

            cli_logger.old_error(logger, "{}Error executing: {}\n",
                                 self.log_prefix, error_str)

            cli_logger.error("!!!")
            if hasattr(e, "cmd"):
                cli_logger.error(
                    "Setup command `{}` failed with exit code {}. stderr:",
                    cf.bold(e.cmd), e.returncode)
            else:
                cli_logger.verbose_error("{}", str(vars(e)))
                # todo: handle this better somehow?
                cli_logger.error("{}", str(e))
            # todo: print stderr here
            cli_logger.error("!!!")
            cli_logger.newline()

            if isinstance(e, click.ClickException):
                # todo: why do we ignore this here
                return
            raise

        tags_to_set = {
            TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
            TAG_RAY_RUNTIME_CONFIG: self.runtime_hash,
        }
        if self.file_mounts_contents_hash is not None:
            tags_to_set[
                TAG_RAY_FILE_MOUNTS_CONTENTS] = self.file_mounts_contents_hash

        self.provider.set_node_tags(self.node_id, tags_to_set)
        cli_logger.labeled_value("New status", STATUS_UP_TO_DATE)

        self.exitcode = 0
Example 11
def kill_node(config_file, yes, hard, override_cluster_name):
    """Kills a random Raylet worker."""

    config = yaml.safe_load(open(config_file).read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = _bootstrap_config(config)

    cli_logger.confirm(yes, "A random node will be killed.")
    cli_logger.old_confirm("This will kill a node in your cluster", yes)

    provider = get_node_provider(config["provider"], config["cluster_name"])
    try:
        nodes = provider.non_terminated_nodes({
            TAG_RAY_NODE_KIND: NODE_KIND_WORKER
        })
        node = random.choice(nodes)
        cli_logger.print("Shutdown " + cf.bold("{}"), node)
        cli_logger.old_info(logger, "kill_node: Shutdown worker {}", node)
        if hard:
            provider.terminate_node(node)
        else:
            updater = NodeUpdaterThread(
                node_id=node,
                provider_config=config["provider"],
                provider=provider,
                auth_config=config["auth"],
                cluster_name=config["cluster_name"],
                file_mounts=config["file_mounts"],
                initialization_commands=[],
                setup_commands=[],
                ray_start_commands=[],
                runtime_hash="",
                file_mounts_contents_hash="",
                is_head_node=False,
                docker_config=config.get("docker"))

            _exec(updater, "ray stop", False, False)

        time.sleep(5)

        if config.get("provider", {}).get("use_internal_ips", False) is True:
            node_ip = provider.internal_ip(node)
        else:
            node_ip = provider.external_ip(node)
    finally:
        provider.cleanup()

    return node_ip
Example 12
    def terminate_nodes(self, node_ids):
        if not node_ids:
            return
        if self.cache_stopped_nodes:
            spot_ids = []
            on_demand_ids = []

            for node_id in node_ids:
                if self._get_cached_node(node_id).spot_instance_request_id:
                    spot_ids += [node_id]
                else:
                    on_demand_ids += [node_id]

            if on_demand_ids:
                # todo: show node names?
                cli_logger.print(
                    "Stopping instances {} " + cf.gray(
                        "(to terminate instead, "
                        "set `cache_stopped_nodes: False` "
                        "under `provider` in the cluster configuration)"),
                    cli_logger.render_list(on_demand_ids))
                cli_logger.old_info(
                    logger,
                    "AWSNodeProvider: stopping nodes {}. To terminate nodes "
                    "on stop, set 'cache_stopped_nodes: False' in the "
                    "provider config.", on_demand_ids)

                self.ec2.meta.client.stop_instances(InstanceIds=on_demand_ids)
            if spot_ids:
                cli_logger.print(
                    "Terminating instances {} " +
                    cf.gray("(cannot stop spot instances, only terminate)"),
                    cli_logger.render_list(spot_ids))
                cli_logger.old_info(
                    logger,
                    "AWSNodeProvider: terminating nodes {} (spot nodes cannot "
                    "be stopped, only terminated)", spot_ids)

                self.ec2.meta.client.terminate_instances(InstanceIds=spot_ids)
        else:
            self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)

        for node_id in node_ids:
            self.tag_cache.pop(node_id, None)
            self.tag_cache_pending.pop(node_id, None)
Example 13
def _create_security_group(config, vpc_id, group_name):
    client = _client("ec2", config)
    client.create_security_group(
        Description="Auto-created security group for Ray workers",
        GroupName=group_name,
        VpcId=vpc_id)
    security_group = _get_security_group(config, vpc_id, group_name)

    cli_logger.verbose("Created new security group {}",
                       cf.bold(security_group.group_name),
                       _tags=dict(id=security_group.id))
    cli_logger.old_info(
        logger, "_create_security_group: Created new security group {} ({})",
        security_group.group_name, security_group.id)

    cli_logger.doassert(security_group,
                        "Failed to create security group")  # err msg
    assert security_group, "Failed to create security group"
    return security_group
Example 14
    def wait_for_ip(self, deadline):
        # If we already have an IP, do not print waiting info.
        ip = self._get_node_ip()
        if ip is not None:
            cli_logger.labeled_value("Fetched IP", ip)
            return ip

        interval = 10
        with cli_logger.timed("Waiting for IP"):
            while time.time() < deadline and \
                    not self.provider.is_terminated(self.node_id):
                cli_logger.old_info(logger, "{}Waiting for IP...",
                                    self.log_prefix)

                ip = self._get_node_ip()
                if ip is not None:
                    cli_logger.labeled_value("Received", ip)
                    return ip
                cli_logger.print("Not yet available, retrying in {} seconds",
                                 cf.bold(str(interval)))
                time.sleep(interval)

        return None
Example 15
File: updater.py Project: aeli0/ray
    def wait_ready(self, deadline):
        with cli_logger.group(
                "Waiting for SSH to become available", _numbered=("[]", 1, 6)):
            with LogTimer(self.log_prefix + "Got remote shell"):
                cli_logger.old_info(logger, "{}Waiting for remote shell...",
                                    self.log_prefix)

                cli_logger.print("Running `{}` as a test.", cf.bold("uptime"))
                while time.time() < deadline and \
                        not self.provider.is_terminated(self.node_id):
                    try:
                        cli_logger.old_debug(logger,
                                             "{}Waiting for remote shell...",
                                             self.log_prefix)

                        self.cmd_runner.run("uptime")
                        cli_logger.old_debug(logger, "Uptime succeeded.")
                        cli_logger.success("Success.")
                        return True
                    except Exception as e:
                        retry_str = str(e)
                        if hasattr(e, "cmd"):
                            retry_str = "(Exit Status {}): {}".format(
                                e.returncode, " ".join(e.cmd))

                        cli_logger.print(
                            "SSH still not available {}, "
                            "retrying in {} seconds.", cf.gray(retry_str),
                            cf.bold(str(READY_CHECK_INTERVAL)))
                        cli_logger.old_debug(logger,
                                             "{}Node not up, retrying: {}",
                                             self.log_prefix, retry_str)

                        time.sleep(READY_CHECK_INTERVAL)

        assert False, "Unable to connect to node"
Example 16
def _configure_key_pair(config):
    if "ssh_private_key" in config["auth"]:
        _set_config_info(keypair_src="config")

        cli_logger.doassert(  # todo: verify schema beforehand?
            "KeyName" in config["head_node"],
            "`KeyName` missing for head node.")  # todo: err msg
        cli_logger.doassert(
            "KeyName" in config["worker_nodes"],
            "`KeyName` missing for worker nodes.")  # todo: err msg

        assert "KeyName" in config["head_node"]
        assert "KeyName" in config["worker_nodes"]
        return config
    _set_config_info(keypair_src="default")

    ec2 = _resource("ec2", config)

    # Try a few times to get or create a good key pair.
    MAX_NUM_KEYS = 30
    for i in range(MAX_NUM_KEYS):

        key_name = config["provider"].get("key_pair", {}).get("key_name")

        key_name, key_path = key_pair(i, config["provider"]["region"],
                                      key_name)
        key = _get_key(key_name, config)

        # Found a good key.
        if key and os.path.exists(key_path):
            break

        # We can safely create a new key.
        if not key and not os.path.exists(key_path):
            cli_logger.verbose(
                "Creating new key pair {} for use as the default.",
                cf.bold(key_name))
            cli_logger.old_info(
                logger, "_configure_key_pair: "
                "Creating new key pair {}", key_name)
            key = ec2.create_key_pair(KeyName=key_name)

            # We need to make sure to _create_ the file with the right
            # permissions. In order to do that we need to change the default
            # os.open behavior to include the mode we want.
            with open(key_path, "w", opener=partial(os.open, mode=0o600)) as f:
                f.write(key.key_material)
            break

    if not key:
        cli_logger.abort(
            "No matching local key file for any of the key pairs in this "
            "account with ids from 0..{}. "
            "Consider deleting some unused keys pairs from your account.",
            key_name)  # todo: err msg
        raise ValueError(
            "No matching local key file for any of the key pairs in this "
            "account with ids from 0..{}. ".format(key_name) +
            "Consider deleting some unused keys pairs from your account.")

    cli_logger.doassert(os.path.exists(key_path), "Private key file " +
                        cf.bold("{}") + " not found for " + cf.bold("{}"),
                        key_path, key_name)  # todo: err msg
    assert os.path.exists(key_path), \
        "Private key file {} not found for {}".format(key_path, key_name)

    cli_logger.old_info(
        logger, "_configure_key_pair: "
        "KeyName not specified for nodes, using {}", key_name)

    config["auth"]["ssh_private_key"] = key_path
    config["head_node"]["KeyName"] = key_name
    config["worker_nodes"]["KeyName"] = key_name

    return config
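The `opener=partial(os.open, mode=0o600)` idiom above is what guarantees the private key file is owner-only from the moment it is created, rather than being created with default permissions and restricted afterwards. A minimal standalone sketch (not Ray code):

import os
from functools import partial

# Create the file owner-read/write only from the start; the opener passes
# mode=0o600 straight to os.open when the file is first created.
with open("demo_key.pem", "w", opener=partial(os.open, mode=0o600)) as f:
    f.write("key material goes here\n")

# Subject to the process umask; with a typical umask this prints 0o600.
print(oct(os.stat("demo_key.pem").st_mode & 0o777))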
Example 17
def _configure_iam_role(config):
    if "IamInstanceProfile" in config["head_node"]:
        _set_config_info(head_instance_profile_src="config")
        return config
    _set_config_info(head_instance_profile_src="default")

    profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)

    if profile is None:
        cli_logger.verbose(
            "Creating new IAM instance profile {} for use as the default.",
            cf.bold(DEFAULT_RAY_INSTANCE_PROFILE))
        cli_logger.old_info(
            logger, "_configure_iam_role: "
            "Creating new instance profile {}", DEFAULT_RAY_INSTANCE_PROFILE)
        client = _client("iam", config)
        client.create_instance_profile(
            InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
        profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
        time.sleep(15)  # wait for propagation

    cli_logger.doassert(profile is not None,
                        "Failed to create instance profile.")  # todo: err msg
    assert profile is not None, "Failed to create instance profile"

    if not profile.roles:
        role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
        if role is None:
            cli_logger.verbose(
                "Creating new IAM role {} for "
                "use as the default instance role.",
                cf.bold(DEFAULT_RAY_IAM_ROLE))
            cli_logger.old_info(logger, "_configure_iam_role: "
                                "Creating new role {}", DEFAULT_RAY_IAM_ROLE)
            iam = _resource("iam", config)
            iam.create_role(RoleName=DEFAULT_RAY_IAM_ROLE,
                            AssumeRolePolicyDocument=json.dumps({
                                "Statement": [
                                    {
                                        "Effect": "Allow",
                                        "Principal": {
                                            "Service": "ec2.amazonaws.com"
                                        },
                                        "Action": "sts:AssumeRole",
                                    },
                                ],
                            }))
            role = _get_role(DEFAULT_RAY_IAM_ROLE, config)

            cli_logger.doassert(role is not None,
                                "Failed to create role.")  # todo: err msg
            assert role is not None, "Failed to create role"
        role.attach_policy(
            PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess")
        role.attach_policy(
            PolicyArn="arn:aws:iam::aws:policy/AmazonS3FullAccess")
        profile.add_role(RoleName=role.name)
        time.sleep(15)  # wait for propagation

    cli_logger.old_info(
        logger, "_configure_iam_role: "
        "Role not specified for head node, using {}", profile.arn)
    config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}

    return config
Example 18
def exec_cluster(config_file: str,
                 *,
                 cmd: Any = None,
                 run_env: str = "auto",
                 screen: bool = False,
                 tmux: bool = False,
                 stop: bool = False,
                 start: bool = False,
                 override_cluster_name: Optional[str] = None,
                 no_config_cache: bool = False,
                 port_forward: Any = None,
                 with_output: bool = False):
    """Runs a command on the specified cluster.

    Arguments:
        config_file: path to the cluster yaml
        cmd: command to run
        run_env: whether to run the command on the host or in a container.
            Select between "auto", "host" and "docker"
        screen: whether to run in a screen
        tmux: whether to run in a tmux session
        stop: whether to stop the cluster after command run
        start: whether to start the cluster if it isn't up
        override_cluster_name: set the name of the cluster
        port_forward (int or list[int]): port(s) to forward
    """
    assert not (screen and tmux), "Can specify only one of `screen` or `tmux`."
    assert run_env in RUN_ENV_TYPES, "--run_env must be in {}".format(
        RUN_ENV_TYPES)
    # TODO(rliaw): We default this to True to maintain backwards-compat.
    # In the future we would want to support disabling login-shells
    # and interactivity.
    cmd_output_util.set_allow_interactive(True)

    config = yaml.safe_load(open(config_file).read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = _bootstrap_config(config, no_config_cache=no_config_cache)

    head_node = _get_head_node(config,
                               config_file,
                               override_cluster_name,
                               create_if_needed=start)

    provider = get_node_provider(config["provider"], config["cluster_name"])
    try:
        updater = NodeUpdaterThread(node_id=head_node,
                                    provider_config=config["provider"],
                                    provider=provider,
                                    auth_config=config["auth"],
                                    cluster_name=config["cluster_name"],
                                    file_mounts=config["file_mounts"],
                                    initialization_commands=[],
                                    setup_commands=[],
                                    ray_start_commands=[],
                                    runtime_hash="",
                                    file_mounts_contents_hash="",
                                    is_head_node=True,
                                    docker_config=config.get("docker"))

        is_docker = isinstance(updater.cmd_runner, DockerCommandRunner)

        if cmd and stop:
            cmd += "; ".join([
                "ray stop",
                "ray teardown ~/ray_bootstrap_config.yaml --yes --workers-only"
            ])
            if is_docker and run_env == "docker":
                updater.cmd_runner.shutdown_after_next_cmd()
            else:
                cmd += "; sudo shutdown -h now"

        result = _exec(updater,
                       cmd,
                       screen,
                       tmux,
                       port_forward=port_forward,
                       with_output=with_output,
                       run_env=run_env)
        if tmux or screen:
            attach_command_parts = ["ray attach", config_file]
            if override_cluster_name is not None:
                attach_command_parts.append(
                    "--cluster-name={}".format(override_cluster_name))
            if tmux:
                attach_command_parts.append("--tmux")
            elif screen:
                attach_command_parts.append("--screen")

            attach_command = " ".join(attach_command_parts)
            cli_logger.print("Run `{}` to check command status.",
                             cf.bold(attach_command))

            attach_info = "Use `{}` to check on command status.".format(
                attach_command)
            cli_logger.old_info(logger, attach_info)
        return result
    finally:
        provider.cleanup()
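A minimal usage sketch for the function above, using only keyword arguments that appear in its signature (the cluster YAML path is hypothetical):

# Run a single command on the head node of an existing cluster and
# capture its output; the cluster is neither started nor stopped.
output = exec_cluster(
    "cluster.yaml",
    cmd="uptime",
    run_env="auto",
    start=False,
    stop=False,
    with_output=True)
print(output)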
Example 19
def get_or_create_head_node(config,
                            config_file,
                            no_restart,
                            restart_only,
                            yes,
                            override_cluster_name,
                            _provider=None,
                            _runner=subprocess):
    """Create the cluster head node, which in turn creates the workers."""
    provider = (_provider or get_node_provider(config["provider"],
                                               config["cluster_name"]))

    config = copy.deepcopy(config)
    raw_config_file = config_file  # used for printing to the user
    config_file = os.path.abspath(config_file)
    try:
        head_node_tags = {
            TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
        }
        nodes = provider.non_terminated_nodes(head_node_tags)
        if len(nodes) > 0:
            head_node = nodes[0]
        else:
            head_node = None

        if not head_node:
            cli_logger.confirm(yes, "No head node found. "
                               "Launching a new cluster.",
                               _abort=True)
            cli_logger.old_confirm("This will create a new cluster", yes)
        elif not no_restart:
            cli_logger.old_confirm("This will restart cluster services", yes)

        if head_node:
            if restart_only:
                cli_logger.confirm(
                    yes, "Updating cluster configuration and "
                    "restarting the cluster Ray runtime. "
                    "Setup commands will not be run due to `{}`.\n",
                    cf.bold("--restart-only"),
                    _abort=True)
            elif no_restart:
                cli_logger.print(
                    "Cluster Ray runtime will not be restarted due "
                    "to `{}`.", cf.bold("--no-restart"))
                cli_logger.confirm(yes, "Updating cluster configuration and "
                                   "running setup commands.",
                                   _abort=True)
            else:
                cli_logger.print(
                    "Updating cluster configuration and running full setup.")
                cli_logger.confirm(
                    yes,
                    cf.bold("Cluster Ray runtime will be restarted."),
                    _abort=True)
        cli_logger.newline()

        # TODO(ekl) this logic is duplicated in node_launcher.py (keep in sync)
        head_node_config = copy.deepcopy(config["head_node"])
        if "head_node_type" in config:
            head_node_tags[TAG_RAY_USER_NODE_TYPE] = config["head_node_type"]
            head_node_config.update(config["available_node_types"][
                config["head_node_type"]]["node_config"])

        launch_hash = hash_launch_conf(head_node_config, config["auth"])
        if head_node is None or provider.node_tags(head_node).get(
                TAG_RAY_LAUNCH_CONFIG) != launch_hash:
            with cli_logger.group("Acquiring an up-to-date head node"):
                if head_node is not None:
                    cli_logger.print(
                        "Currently running head node is out-of-date with "
                        "cluster configuration")
                    cli_logger.print(
                        "hash is {}, expected {}",
                        cf.bold(
                            provider.node_tags(head_node).get(
                                TAG_RAY_LAUNCH_CONFIG)), cf.bold(launch_hash))
                    cli_logger.confirm(yes, "Relaunching it.", _abort=True)
                    cli_logger.old_confirm(
                        "Head node config out-of-date. It will be terminated",
                        yes)

                    cli_logger.old_info(
                        logger, "get_or_create_head_node: "
                        "Shutting down outdated head node {}", head_node)

                    provider.terminate_node(head_node)
                    cli_logger.print("Terminated head node {}", head_node)

                cli_logger.old_info(
                    logger,
                    "get_or_create_head_node: Launching new head node...")

                head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
                head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
                    config["cluster_name"])
                provider.create_node(head_node_config, head_node_tags, 1)
                cli_logger.print("Launched a new head node")

                start = time.time()
                head_node = None
                with cli_logger.timed("Fetching the new head node"):
                    while True:
                        if time.time() - start > 50:
                            cli_logger.abort(
                                "Head node fetch timed out.")  # todo: msg
                            raise RuntimeError("Failed to create head node.")
                        nodes = provider.non_terminated_nodes(head_node_tags)
                        if len(nodes) == 1:
                            head_node = nodes[0]
                            break
                        time.sleep(1)
                cli_logger.newline()

        with cli_logger.group(
                "Setting up head node",
                _numbered=("<>", 1, 1),
                # cf.bold(provider.node_tags(head_node)[TAG_RAY_NODE_NAME]),
                _tags=dict()):  # add id, ARN to tags?

            # TODO(ekl) right now we always update the head node even if the
            # hash matches.
            # We could prompt the user for what they want to do here.
            # No need to pass in cluster_sync_files because we use this
            # hash to set up the head node
            (runtime_hash, file_mounts_contents_hash) = hash_runtime_conf(
                config["file_mounts"], None, config)

            cli_logger.old_info(
                logger,
                "get_or_create_head_node: Updating files on head node...")

            # Rewrite the auth config so that the head
            # node can update the workers
            remote_config = copy.deepcopy(config)

            # drop proxy options if they exist, otherwise
            # head node won't be able to connect to workers
            remote_config["auth"].pop("ssh_proxy_command", None)

            if "ssh_private_key" in config["auth"]:
                remote_key_path = "~/ray_bootstrap_key.pem"
                remote_config["auth"]["ssh_private_key"] = remote_key_path

            # Adjust for new file locations
            new_mounts = {}
            for remote_path in config["file_mounts"]:
                new_mounts[remote_path] = remote_path
            remote_config["file_mounts"] = new_mounts
            remote_config["no_restart"] = no_restart

            # Now inject the rewritten config and SSH key into the head node
            remote_config_file = tempfile.NamedTemporaryFile(
                "w", prefix="ray-bootstrap-")
            remote_config_file.write(json.dumps(remote_config))
            remote_config_file.flush()
            config["file_mounts"].update(
                {"~/ray_bootstrap_config.yaml": remote_config_file.name})

            if "ssh_private_key" in config["auth"]:
                config["file_mounts"].update({
                    remote_key_path:
                    config["auth"]["ssh_private_key"],
                })
            cli_logger.print("Prepared bootstrap config")

            if restart_only:
                setup_commands = []
                ray_start_commands = config["head_start_ray_commands"]
            elif no_restart:
                setup_commands = config["head_setup_commands"]
                ray_start_commands = []
            else:
                setup_commands = config["head_setup_commands"]
                ray_start_commands = config["head_start_ray_commands"]

            if not no_restart:
                warn_about_bad_start_command(ray_start_commands)

            updater = NodeUpdaterThread(
                node_id=head_node,
                provider_config=config["provider"],
                provider=provider,
                auth_config=config["auth"],
                cluster_name=config["cluster_name"],
                file_mounts=config["file_mounts"],
                initialization_commands=config["initialization_commands"],
                setup_commands=setup_commands,
                ray_start_commands=ray_start_commands,
                process_runner=_runner,
                runtime_hash=runtime_hash,
                file_mounts_contents_hash=file_mounts_contents_hash,
                is_head_node=True,
                docker_config=config.get("docker"))
            updater.start()
            updater.join()

            # Refresh the node cache so we see the external ip if available
            provider.non_terminated_nodes(head_node_tags)

            if config.get("provider", {}).get("use_internal_ips",
                                              False) is True:
                head_node_ip = provider.internal_ip(head_node)
            else:
                head_node_ip = provider.external_ip(head_node)

            if updater.exitcode != 0:
                # todo: this does not follow the mockup and is not good enough
                cli_logger.abort("Failed to setup head node.")

                cli_logger.old_error(
                    logger, "get_or_create_head_node: "
                    "Updating {} failed", head_node_ip)
                sys.exit(1)

            cli_logger.old_info(
                logger, "get_or_create_head_node: "
                "Head node up-to-date, IP address is: {}", head_node_ip)

        monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*"
        if override_cluster_name:
            modifiers = " --cluster-name={}".format(
                quote(override_cluster_name))
        else:
            modifiers = ""

        if cli_logger.old_style:
            print("To monitor autoscaling activity, you can run:\n\n"
                  "  ray exec {} {}{}\n".format(config_file,
                                                quote(monitor_str), modifiers))
            print("To open a console on the cluster:\n\n"
                  "  ray attach {}{}\n".format(config_file, modifiers))

            print("To get a remote shell to the cluster manually, run:\n\n"
                  "  {}\n".format(
                      updater.cmd_runner.remote_shell_command_str()))

        cli_logger.newline()
        with cli_logger.group("Useful commands"):
            cli_logger.print("Monitor autoscaling with")
            cli_logger.print(cf.bold("  ray exec {}{} {}"), raw_config_file,
                             modifiers, quote(monitor_str))

            cli_logger.print("Connect to a terminal on the cluster head")
            cli_logger.print(cf.bold("  ray attach {}{}"), raw_config_file,
                             modifiers)
    finally:
        provider.cleanup()
Example 20
    def _create_node(self, node_config, tags, count):
        tags = to_aws_format(tags)
        conf = node_config.copy()

        # Delete unsupported keys from the node config
        try:
            del conf["Resources"]
        except KeyError:
            pass

        tag_pairs = [{
            "Key": TAG_RAY_CLUSTER_NAME,
            "Value": self.cluster_name,
        }]
        for k, v in tags.items():
            tag_pairs.append({
                "Key": k,
                "Value": v,
            })
        tag_specs = [{
            "ResourceType": "instance",
            "Tags": tag_pairs,
        }]
        user_tag_specs = conf.get("TagSpecifications", [])
        # Allow users to add tags and override values of existing
        # tags with their own. This only applies to the resource type
        # "instance". All other resource types are appended to the list of
        # tag specs.
        for user_tag_spec in user_tag_specs:
            if user_tag_spec["ResourceType"] == "instance":
                for user_tag in user_tag_spec["Tags"]:
                    exists = False
                    for tag in tag_specs[0]["Tags"]:
                        if user_tag["Key"] == tag["Key"]:
                            exists = True
                            tag["Value"] = user_tag["Value"]
                            break
                    if not exists:
                        tag_specs[0]["Tags"] += [user_tag]
            else:
                tag_specs += [user_tag_spec]

        # SubnetIds is not a real config key: we must resolve to a
        # single SubnetId before invoking the AWS API.
        subnet_ids = conf.pop("SubnetIds")

        for attempt in range(1, BOTO_CREATE_MAX_RETRIES + 1):
            try:
                subnet_id = subnet_ids[self.subnet_idx % len(subnet_ids)]

                cli_logger.old_info(
                    logger, "NodeProvider: calling create_instances "
                    "with {} (count={}).", subnet_id, count)

                self.subnet_idx += 1
                conf.update({
                    "MinCount": 1,
                    "MaxCount": count,
                    "SubnetId": subnet_id,
                    "TagSpecifications": tag_specs
                })
                created = self.ec2_fail_fast.create_instances(**conf)

                # todo: timed?
                # todo: handle plurality?
                with cli_logger.group(
                        "Launching {} nodes",
                        count,
                        _tags=dict(subnet_id=subnet_id)):
                    for instance in created:
                        cli_logger.print(
                            "Launched instance {}",
                            instance.instance_id,
                            _tags=dict(
                                state=instance.state["Name"],
                                info=instance.state_reason["Message"]))
                        cli_logger.old_info(
                            logger, "NodeProvider: Created instance "
                            "[id={}, name={}, info={}]", instance.instance_id,
                            instance.state["Name"],
                            instance.state_reason["Message"])
                break
            except botocore.exceptions.ClientError as exc:
                if attempt == BOTO_CREATE_MAX_RETRIES:
                    # todo: err msg
                    cli_logger.abort(
                        "Failed to launch instances. Max attempts exceeded.")
                    cli_logger.old_error(
                        logger,
                        "create_instances: Max attempts ({}) exceeded.",
                        BOTO_CREATE_MAX_RETRIES)
                    raise exc
                else:
                    # todo: err msg
                    cli_logger.abort(exc)
                    cli_logger.old_error(logger, exc)
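The user-tag merge loop above can be isolated into a small sketch to show the intended behavior (hypothetical tag values):

# Default tags produced by the provider for the "instance" resource type.
tag_specs = [{
    "ResourceType": "instance",
    "Tags": [{"Key": "ray-cluster-name", "Value": "demo"}],
}]
# Hypothetical user-supplied TagSpecifications from the node config.
user_tag_specs = [
    {"ResourceType": "instance",
     "Tags": [{"Key": "ray-cluster-name", "Value": "override"},
              {"Key": "team", "Value": "ml"}]},
    {"ResourceType": "volume",
     "Tags": [{"Key": "team", "Value": "ml"}]},
]

for user_tag_spec in user_tag_specs:
    if user_tag_spec["ResourceType"] == "instance":
        for user_tag in user_tag_spec["Tags"]:
            for tag in tag_specs[0]["Tags"]:
                if user_tag["Key"] == tag["Key"]:
                    tag["Value"] = user_tag["Value"]
                    break
            else:
                tag_specs[0]["Tags"].append(user_tag)
    else:
        tag_specs.append(user_tag_spec)

# Result: the instance spec now carries the overridden cluster-name value
# plus the extra "team" tag, and the "volume" spec is appended unchanged.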
Example 21
    def do_update(self):
        self.provider.set_node_tags(
            self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
        cli_logger.labeled_value("New status", STATUS_WAITING_FOR_SSH)

        deadline = time.time() + NODE_START_WAIT_S
        self.wait_ready(deadline)

        node_tags = self.provider.node_tags(self.node_id)
        logger.debug("Node tags: {}".format(str(node_tags)))

        # runtime_hash will only change whenever the user restarts
        # or updates their cluster with `get_or_create_head_node`
        if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash and (
                self.file_mounts_contents_hash is None
                or node_tags.get(TAG_RAY_FILE_MOUNTS_CONTENTS)
                == self.file_mounts_contents_hash):
            # todo: we lie in the confirmation message since
            # full setup might be cancelled here
            cli_logger.print(
                "Configuration already up to date, "
                "skipping file mounts, initalization and setup commands.",
                _numbered=("[]", "2-5", 6))
            cli_logger.old_info(logger,
                                "{}{} already up-to-date, skip to ray start",
                                self.log_prefix, self.node_id)

            # When resuming from a stopped instance the runtime_hash may be the
            # same, but the container will not be started.
            self.cmd_runner.run_init(as_head=self.is_head_node,
                                     file_mounts=self.file_mounts)

        else:
            cli_logger.print("Updating cluster configuration.",
                             _tags=dict(hash=self.runtime_hash))

            self.provider.set_node_tags(
                self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES})
            cli_logger.labeled_value("New status", STATUS_SYNCING_FILES)
            self.sync_file_mounts(self.rsync_up, step_numbers=(2, 6))

            # Only run setup commands if runtime_hash has changed because
            # we don't want to run setup_commands every time the head node
            # file_mounts folders have changed.
            if node_tags.get(TAG_RAY_RUNTIME_CONFIG) != self.runtime_hash:
                # Run init commands
                self.provider.set_node_tags(
                    self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP})
                cli_logger.labeled_value("New status", STATUS_SETTING_UP)

                if self.initialization_commands:
                    with cli_logger.group("Running initialization commands",
                                          _numbered=("[]", 3, 5)):
                        with LogTimer(self.log_prefix +
                                      "Initialization commands",
                                      show_status=True):
                            for cmd in self.initialization_commands:
                                try:
                                    # Overriding the existing SSHOptions class
                                    # with a new SSHOptions class that uses
                                    # this ssh_private_key as its only __init__
                                    # argument.
                                    # Run outside docker.
                                    self.cmd_runner.run(
                                        cmd,
                                        ssh_options_override_ssh_key=self.
                                        auth_config.get("ssh_private_key"),
                                        run_env="host")
                                except ProcessRunnerError as e:
                                    if e.msg_type == "ssh_command_failed":
                                        cli_logger.error("Failed.")
                                        cli_logger.error(
                                            "See above for stderr.")

                                    raise click.ClickException(
                                        "Initialization command failed."
                                    ) from None
                else:
                    cli_logger.print("No initialization commands to run.",
                                     _numbered=("[]", 3, 6))
                self.cmd_runner.run_init(as_head=self.is_head_node,
                                         file_mounts=self.file_mounts)
                if self.setup_commands:
                    with cli_logger.group(
                            "Running setup commands",
                            # todo: fix command numbering
                            _numbered=("[]", 4, 6)):
                        with LogTimer(self.log_prefix + "Setup commands",
                                      show_status=True):

                            total = len(self.setup_commands)
                            for i, cmd in enumerate(self.setup_commands):
                                if cli_logger.verbosity == 0 and len(cmd) > 30:
                                    cmd_to_print = cf.bold(cmd[:30]) + "..."
                                else:
                                    cmd_to_print = cf.bold(cmd)

                                cli_logger.print("{}",
                                                 cmd_to_print,
                                                 _numbered=("()", i, total))

                                try:
                                    # Runs in the container if docker is in use
                                    self.cmd_runner.run(cmd, run_env="auto")
                                except ProcessRunnerError as e:
                                    if e.msg_type == "ssh_command_failed":
                                        cli_logger.error("Failed.")
                                        cli_logger.error(
                                            "See above for stderr.")

                                    raise click.ClickException(
                                        "Setup command failed.")
                else:
                    cli_logger.print("No setup commands to run.",
                                     _numbered=("[]", 4, 6))

        with cli_logger.group("Starting the Ray runtime",
                              _numbered=("[]", 6, 6)):
            with LogTimer(self.log_prefix + "Ray start commands",
                          show_status=True):
                for cmd in self.ray_start_commands:
                    if self.node_resources:
                        env_vars = {
                            ray_constants.RESOURCES_ENVIRONMENT_VARIABLE:
                            self.node_resources
                        }
                    else:
                        env_vars = {}
                    try:
                        old_redirected = cmd_output_util.is_output_redirected()
                        cmd_output_util.set_output_redirected(False)
                        # Runs in the container if docker is in use
                        self.cmd_runner.run(cmd,
                                            environment_variables=env_vars,
                                            run_env="auto")
                        cmd_output_util.set_output_redirected(old_redirected)
                    except ProcessRunnerError as e:
                        if e.msg_type == "ssh_command_failed":
                            cli_logger.error("Failed.")
                            cli_logger.error("See above for stderr.")

                        raise click.ClickException("Start command failed.")
Example 22
    def run(self,
            cmd,
            timeout=120,
            exit_on_fail=False,
            port_forward=None,
            with_output=False,
            ssh_options_override=None,
            **kwargs):
        ssh_options = ssh_options_override or self.ssh_options

        assert isinstance(
            ssh_options, SSHOptions
        ), "ssh_options must be of type SSHOptions, got {}".format(
            type(ssh_options))

        self._set_ssh_ip_if_required()

        ssh = ["ssh", "-tt"]

        if port_forward:
            with cli_logger.group("Forwarding ports"):
                if not isinstance(port_forward, list):
                    port_forward = [port_forward]
                for local, remote in port_forward:
                    cli_logger.verbose(
                        "Forwarding port {} to port {} on localhost.",
                        cf.bold(local), cf.bold(remote))  # todo: msg
                    cli_logger.old_info(logger,
                                        "{}Forwarding {} -> localhost:{}",
                                        self.log_prefix, local, remote)
                    ssh += ["-L", "{}:localhost:{}".format(remote, local)]

        final_cmd = ssh + ssh_options.to_ssh_options_list(
            timeout=timeout) + ["{}@{}".format(self.ssh_user, self.ssh_ip)]
        if cmd:
            final_cmd += _with_interactive(cmd)
            cli_logger.old_info(logger, "{}Running {}", self.log_prefix,
                                " ".join(final_cmd))
        else:
            # We do this because `-o ControlMaster` causes the `-N` flag to
            # still create an interactive shell in some ssh versions.
            final_cmd.append(quote("while true; do sleep 86400; done"))

        # todo: add a flag for this; we might
        # want to log commands with print sometimes
        cli_logger.verbose("Running `{}`", cf.bold(cmd))
        with cli_logger.indented():
            cli_logger.very_verbose("Full command is `{}`",
                                    cf.bold(" ".join(final_cmd)))

        def start_process():
            try:
                if with_output:
                    return self.process_runner.check_output(final_cmd)
                else:
                    self.process_runner.check_call(final_cmd)
            except subprocess.CalledProcessError as e:
                quoted_cmd = " ".join(final_cmd[:-1] + [quote(final_cmd[-1])])
                if not cli_logger.old_style:
                    raise ProcessRunnerError("Command failed",
                                             "ssh_command_failed",
                                             code=e.returncode,
                                             command=quoted_cmd)

                if exit_on_fail:
                    raise click.ClickException(
                        "Command failed: \n\n  {}\n".format(quoted_cmd)) \
                        from None
                else:
                    raise click.ClickException(
                        "SSH command Failed. See above for the output from the"
                        " failure.") from None

        if cli_logger.verbosity > 0:
            with cli_logger.indented():
                return start_process()
        else:
            return start_process()
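
A minimal usage sketch of the run method above (the runner object, the command, and the port numbers are illustrative, not taken from the source): each (local, remote) pair in port_forward is expanded into one ssh -L argument, and with_output=True switches the process runner from check_call to check_output so the remote stdout is returned.

    # Hedged sketch; assumes a command runner of the class above was already
    # constructed with valid SSHOptions and a reachable node.
    output = runner.run(
        "uptime",                     # command executed on the remote node
        timeout=60,                   # passed through to the SSH options
        port_forward=[(8265, 8265)],  # one (local, remote) pair per "-L" flag
        with_output=True)             # return stdout via check_output
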
Example No. 23
0
    def create_node(self, node_config, tags, count):
        # Always add the instance type tag, since node reuse is unsafe
        # otherwise.
        tags = copy.deepcopy(tags)
        tags[TAG_RAY_INSTANCE_TYPE] = node_config["InstanceType"]
        # Try to reuse previously stopped nodes with compatible configs
        if self.cache_stopped_nodes:
            filters = [
                {
                    "Name": "instance-state-name",
                    "Values": ["stopped", "stopping"],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
                    "Values": [self.cluster_name],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_NODE_TYPE),
                    "Values": [tags[TAG_RAY_NODE_TYPE]],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_INSTANCE_TYPE),
                    "Values": [tags[TAG_RAY_INSTANCE_TYPE]],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG),
                    "Values": [tags[TAG_RAY_LAUNCH_CONFIG]],
                },
            ]

            reuse_nodes = list(
                self.ec2.instances.filter(Filters=filters))[:count]
            reuse_node_ids = [n.id for n in reuse_nodes]
            if reuse_nodes:
                cli_logger.print(
                    # todo: handle plural vs singular?
                    "Reusing nodes {}. "
                    "To disable reuse, set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration.",
                    cli_logger.render_list(reuse_node_ids))
                cli_logger.old_info(
                    logger, "AWSNodeProvider: reusing instances {}. "
                    "To disable reuse, set "
                    "'cache_stopped_nodes: False' in the provider "
                    "config.", reuse_node_ids)

                # todo: timed?
                with cli_logger.group("Stopping instances to reuse"):
                    for node in reuse_nodes:
                        self.tag_cache[node.id] = from_aws_format(
                            {x["Key"]: x["Value"]
                             for x in node.tags})
                        if node.state["Name"] == "stopping":
                            cli_logger.print("Waiting for instance {} to stop",
                                             node.id)
                            cli_logger.old_info(
                                logger,
                                "AWSNodeProvider: waiting for instance "
                                "{} to fully stop...", node.id)
                            node.wait_until_stopped()

                self.ec2.meta.client.start_instances(
                    InstanceIds=reuse_node_ids)
                for node_id in reuse_node_ids:
                    self.set_node_tags(node_id, tags)
                count -= len(reuse_node_ids)

        if count:
            self._create_node(node_config, tags, count)
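
To make the reuse path concrete, a hedged sketch of how a caller might invoke create_node (the instance type, tag values, and launch_hash are illustrative, not from the source): stopped instances are reused only when the cluster-name, node-type, instance-type, and launch-config tags all match, and any remaining count is launched fresh through _create_node.

    # Hypothetical call; the tag constants are the ones referenced above,
    # but the concrete values are made up for illustration.
    provider.create_node(
        node_config={"InstanceType": "m5.large"},  # copied into TAG_RAY_INSTANCE_TYPE
        tags={
            TAG_RAY_NODE_TYPE: "worker",         # matched against stopped instances
            TAG_RAY_LAUNCH_CONFIG: launch_hash,  # illustrative launch-config hash
        },
        count=2)
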
Example No. 24
0
    def resolve(self, is_head, node_ip_address=None):
        """Returns a copy with values filled out with system defaults.

        Args:
            is_head (bool): Whether this is the head node.
            node_ip_address (str): The IP address of the node that we are on.
                This is used to automatically create a node id resource.
        """

        resources = (self.resources or {}).copy()
        assert "CPU" not in resources, resources
        assert "GPU" not in resources, resources
        assert "memory" not in resources, resources
        assert "object_store_memory" not in resources, resources

        if node_ip_address is None:
            node_ip_address = ray.services.get_node_ip_address()

        # Automatically create a node id resource on each node. This is
        # queryable with ray.state.node_ids() and ray.state.current_node_id().
        resources[NODE_ID_PREFIX + node_ip_address] = 1.0

        num_cpus = self.num_cpus
        if num_cpus is None:
            num_cpus = multiprocessing.cpu_count()

        num_gpus = self.num_gpus
        gpu_ids = ray.utils.get_cuda_visible_devices()
        # Check that the number of GPUs that the raylet wants doesn't
        # exceed the amount allowed by CUDA_VISIBLE_DEVICES.
        if (num_gpus is not None and gpu_ids is not None
                and num_gpus > len(gpu_ids)):
            raise ValueError("Attempting to start raylet with {} GPUs, "
                             "but CUDA_VISIBLE_DEVICES contains {}.".format(
                                 num_gpus, gpu_ids))
        if num_gpus is None:
            # Try to automatically detect the number of GPUs.
            num_gpus = _autodetect_num_gpus()
            # Don't use more GPUs than allowed by CUDA_VISIBLE_DEVICES.
            if gpu_ids is not None:
                num_gpus = min(num_gpus, len(gpu_ids))

        try:
            info_string = _get_gpu_info_string()
            gpu_types = _constraints_from_gpu_info(info_string)
            resources.update(gpu_types)
        except Exception:
            logger.exception("Could not parse gpu information.")

        # Choose a default object store size.
        system_memory = ray.utils.get_system_memory()
        avail_memory = ray.utils.estimate_available_memory()
        object_store_memory = self.object_store_memory
        if object_store_memory is None:
            object_store_memory = int(avail_memory * 0.3)
            # Cap memory to avoid memory waste and perf issues on large nodes
            if (object_store_memory >
                    ray_constants.DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES):
                logger.debug(
                    "Warning: Capping object memory store to {}GB. ".format(
                        ray_constants.DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES //
                        1e9) +
                    "To increase this further, specify `object_store_memory` "
                    "when calling ray.init() or ray start.")
                object_store_memory = (
                    ray_constants.DEFAULT_OBJECT_STORE_MAX_MEMORY_BYTES)

        redis_max_memory = self.redis_max_memory
        if redis_max_memory is None:
            redis_max_memory = min(
                ray_constants.DEFAULT_REDIS_MAX_MEMORY_BYTES,
                max(int(avail_memory * 0.1),
                    ray_constants.REDIS_MINIMUM_MEMORY_BYTES))
        if redis_max_memory < ray_constants.REDIS_MINIMUM_MEMORY_BYTES:
            raise ValueError(
                "Attempting to cap Redis memory usage at {} bytes, "
                "but the minimum allowed is {} bytes.".format(
                    redis_max_memory,
                    ray_constants.REDIS_MINIMUM_MEMORY_BYTES))

        memory = self.memory
        if memory is None:
            memory = (avail_memory - object_store_memory -
                      (redis_max_memory if is_head else 0))
            if memory < 100e6 and memory < 0.05 * system_memory:
                raise ValueError(
                    "After taking into account object store and redis memory "
                    "usage, the amount of memory on this node available for "
                    "tasks and actors ({} GB) is less than {}% of total. "
                    "You can adjust these settings with "
                    "ray.init(memory=<bytes>, "
                    "object_store_memory=<bytes>).".format(
                        round(memory / 1e9, 2),
                        int(100 * (memory / system_memory))))

        rounded_memory = ray_constants.round_to_memory_units(memory,
                                                             round_up=False)
        worker_ram = round(rounded_memory / (1024**3), 2)
        object_ram = round(object_store_memory / (1024**3), 2)

        # TODO(maximsmol): this behavior is strange since we do not have a
        # good grasp on when this will get called
        # (you have to study node.py to make a guess)
        with cli_logger.group("Available RAM"):
            cli_logger.labeled_value("Workers", "{} GiB", str(worker_ram))
            cli_logger.labeled_value("Objects", "{} GiB", str(object_ram))
            cli_logger.newline()
            cli_logger.print("To adjust these values, use")
            with cf.with_style("monokai") as c:
                cli_logger.print(
                    "  ray{0}init(memory{1}{2}, "
                    "object_store_memory{1}{2})", c.magenta("."),
                    c.magenta("="), c.purple("<bytes>"))

        cli_logger.old_info(
            logger,
            "Starting Ray with {} GiB memory available for workers and up to "
            "{} GiB for objects. You can adjust these settings "
            "with ray.init(memory=<bytes>, "
            "object_store_memory=<bytes>).", worker_ram, object_ram)

        spec = ResourceSpec(num_cpus, num_gpus, memory, object_store_memory,
                            resources, redis_max_memory)
        assert spec.resolved()
        return spec
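
The default memory split in resolve can be read off the code above: roughly 30% of the estimated available memory goes to the object store, about 10% to Redis on the head node, and the remainder is left for tasks and actors, subject to the caps and minimums referenced above. A hedged back-of-the-envelope example, assuming none of those caps or minimums apply:

    # Illustrative numbers only; the object store cap and Redis bounds are ignored.
    avail = 10 * 10**9                      # pretend ~10 GB is estimated available
    object_store = int(avail * 0.3)         # -> ~3 GB for the object store
    redis = int(avail * 0.1)                # -> ~1 GB for Redis (head node only)
    workers = avail - object_store - redis  # -> ~6 GB left for tasks and actors
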
Example No. 25
0
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
                     override_cluster_name: Optional[str],
                     keep_min_workers: bool, log_old_style: bool,
                     log_color: str, verbose: int):
    """Destroys all nodes of a Ray cluster described by a config json."""
    cli_logger.old_style = log_old_style
    cli_logger.color_mode = log_color
    cli_logger.verbosity = verbose
    cli_logger.dump_command_output = verbose == 3  # todo: add a separate flag?

    config = yaml.safe_load(open(config_file).read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = prepare_config(config)
    validate_config(config)

    cli_logger.confirm(yes, "Destroying cluster.", _abort=True)
    cli_logger.old_confirm("This will destroy your cluster", yes)

    if not workers_only:
        try:
            exec_cluster(config_file,
                         cmd="ray stop",
                         run_env="auto",
                         screen=False,
                         tmux=False,
                         stop=False,
                         start=False,
                         override_cluster_name=override_cluster_name,
                         port_forward=None,
                         with_output=False)
        except Exception as e:
            # todo: add better exception info
            cli_logger.verbose_error("{}", str(e))
            cli_logger.warning(
                "Exception occured when stopping the cluster Ray runtime "
                "(use -v to dump teardown exceptions).")
            cli_logger.warning(
                "Ignoring the exception and "
                "attempting to shut down the cluster nodes anyway.")

            cli_logger.old_exception(
                logger, "Ignoring error attempting a clean shutdown.")

    provider = get_node_provider(config["provider"], config["cluster_name"])
    try:

        def remaining_nodes():

            workers = provider.non_terminated_nodes(
                {TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER})

            if keep_min_workers:
                min_workers = config.get("min_workers", 0)

                cli_logger.print(
                    "{} random worker nodes will not be shut down. " +
                    cf.gray("(due to {})"), cf.bold(min_workers),
                    cf.bold("--keep-min-workers"))
                cli_logger.old_info(logger,
                                    "teardown_cluster: Keeping {} nodes...",
                                    min_workers)

                workers = random.sample(workers, len(workers) - min_workers)

            # todo: it's weird to kill the head node but not all workers
            if workers_only:
                cli_logger.print(
                    "The head node will not be shut down. " +
                    cf.gray("(due to {})"), cf.bold("--workers-only"))

                return workers

            head = provider.non_terminated_nodes(
                {TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD})

            return head + workers

        # Loop here to check that both the head and worker nodes are actually
        #   really gone
        A = remaining_nodes()
        with LogTimer("teardown_cluster: done."):
            while A:
                cli_logger.old_info(
                    logger, "teardown_cluster: "
                    "Shutting down {} nodes...", len(A))

                provider.terminate_nodes(A)

                cli_logger.print("Requested {} nodes to shut down.",
                                 cf.bold(len(A)),
                                 _tags=dict(interval="1s"))

                time.sleep(1)  # todo: interval should be a variable
                A = remaining_nodes()
                cli_logger.print("{} nodes remaining after 1 second.",
                                 cf.bold(len(A)))
    finally:
        provider.cleanup()
Example No. 26
0
def _configure_subnet(config):
    ec2 = _resource("ec2", config)
    use_internal_ips = config["provider"].get("use_internal_ips", False)

    try:
        subnets = sorted(
            (s for s in ec2.subnets.all() if s.state == "available" and (
                use_internal_ips or s.map_public_ip_on_launch)),
            reverse=True,  # sort from Z-A
            key=lambda subnet: subnet.availability_zone)
    except botocore.exceptions.ClientError as exc:
        handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
        raise exc

    if not subnets:
        cli_logger.abort(
            "No usable subnets found, try manually creating an instance in "
            "your specified region to populate the list of subnets "
            "and trying this again.\n"
            "Note that the subnet must map public IPs "
            "on instance launch unless you set `use_internal_ips: true` in "
            "the `provider` config.")  # todo: err msg
        raise Exception(
            "No usable subnets found, try manually creating an instance in "
            "your specified region to populate the list of subnets "
            "and trying this again. Note that the subnet must map public IPs "
            "on instance launch unless you set 'use_internal_ips': True in "
            "the 'provider' config.")
    if "availability_zone" in config["provider"]:
        azs = config["provider"]["availability_zone"].split(",")
        subnets = [s for s in subnets if s.availability_zone in azs]
        if not subnets:
            cli_logger.abort(
                "No usable subnets matching availability zone {} found.\n"
                "Choose a different availability zone or try "
                "manually creating an instance in your specified region "
                "to populate the list of subnets and trying this again.",
                config["provider"]["availability_zone"])  # todo: err msg
            raise Exception(
                "No usable subnets matching availability zone {} "
                "found. Choose a different availability zone or try "
                "manually creating an instance in your specified region "
                "to populate the list of subnets and trying this again.".
                format(config["provider"]["availability_zone"]))

    subnet_ids = [s.subnet_id for s in subnets]
    subnet_descr = [(s.subnet_id, s.availability_zone) for s in subnets]
    if "SubnetIds" not in config["head_node"]:
        _set_config_info(head_subnet_src="default")
        config["head_node"]["SubnetIds"] = subnet_ids
        cli_logger.old_info(
            logger, "_configure_subnet: "
            "SubnetIds not specified for head node, using {}", subnet_descr)
    else:
        _set_config_info(head_subnet_src="config")

    if "SubnetIds" not in config["worker_nodes"]:
        _set_config_info(workers_subnet_src="default")
        config["worker_nodes"]["SubnetIds"] = subnet_ids
        cli_logger.old_info(
            logger, "_configure_subnet: "
            "SubnetId not specified for workers,"
            " using {}", subnet_descr)
    else:
        _set_config_info(workers_subnet_src="config")

    return config
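
A hedged sketch of the provider section this function reads (the region and zone names are illustrative): when availability_zone is present, only subnets in the listed zones are kept, and the resulting SubnetIds are filled in for head and worker nodes unless those sections already specify them.

    # Illustrative AWS provider config; the keys match those read above.
    config["provider"] = {
        "type": "aws",
        "region": "us-west-2",
        "availability_zone": "us-west-2a,us-west-2b",  # comma-separated zone list
        "use_internal_ips": False,  # require subnets that map public IPs on launch
    }
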
Example No. 27
0
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
                     override_cluster_name: Optional[str],
                     keep_min_workers: bool):
    """Destroys all nodes of a Ray cluster described by a config json."""
    config = yaml.safe_load(open(config_file).read())
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = prepare_config(config)
    validate_config(config)

    cli_logger.confirm(yes, "Destroying cluster.", _abort=True)
    cli_logger.old_confirm("This will destroy your cluster", yes)

    if not workers_only:
        try:
            exec_cluster(config_file,
                         cmd="ray stop",
                         run_env="auto",
                         screen=False,
                         tmux=False,
                         stop=False,
                         start=False,
                         override_cluster_name=override_cluster_name,
                         port_forward=None,
                         with_output=False)
        except Exception as e:
            # todo: add better exception info
            cli_logger.verbose_error("{}", str(e))
            cli_logger.warning(
                "Exception occured when stopping the cluster Ray runtime "
                "(use -v to dump teardown exceptions).")
            cli_logger.warning(
                "Ignoring the exception and "
                "attempting to shut down the cluster nodes anyway.")

            cli_logger.old_exception(
                logger, "Ignoring error attempting a clean shutdown.")

    provider = get_node_provider(config["provider"], config["cluster_name"])
    try:

        def remaining_nodes():
            workers = provider.non_terminated_nodes(
                {TAG_RAY_NODE_KIND: NODE_KIND_WORKER})

            if keep_min_workers:
                min_workers = config.get("min_workers", 0)

                cli_logger.print(
                    "{} random worker nodes will not be shut down. " +
                    cf.dimmed("(due to {})"), cf.bold(min_workers),
                    cf.bold("--keep-min-workers"))
                cli_logger.old_info(logger,
                                    "teardown_cluster: Keeping {} nodes...",
                                    min_workers)

                workers = random.sample(workers, len(workers) - min_workers)

            # todo: it's weird to kill the head node but not all workers
            if workers_only:
                cli_logger.print(
                    "The head node will not be shut down. " +
                    cf.dimmed("(due to {})"), cf.bold("--workers-only"))

                return workers

            head = provider.non_terminated_nodes(
                {TAG_RAY_NODE_KIND: NODE_KIND_HEAD})

            return head + workers

        def run_docker_stop(node, container_name):
            try:
                updater = NodeUpdaterThread(
                    node_id=node,
                    provider_config=config["provider"],
                    provider=provider,
                    auth_config=config["auth"],
                    cluster_name=config["cluster_name"],
                    file_mounts=config["file_mounts"],
                    initialization_commands=[],
                    setup_commands=[],
                    ray_start_commands=[],
                    runtime_hash="",
                    file_mounts_contents_hash="",
                    is_head_node=False,
                    docker_config=config.get("docker"))
                _exec(updater,
                      f"docker stop {container_name}",
                      False,
                      False,
                      run_env="host")
            except Exception:
                cli_logger.warning(f"Docker stop failed on {node}")
                cli_logger.old_warning(logger, f"Docker stop failed on {node}")

        # Loop here to check that both the head and worker nodes are actually
        #   really gone
        A = remaining_nodes()

        container_name = config.get("docker", {}).get("container_name")
        if container_name:
            for node in A:
                run_docker_stop(node, container_name)

        with LogTimer("teardown_cluster: done."):
            while A:
                cli_logger.old_info(
                    logger, "teardown_cluster: "
                    "Shutting down {} nodes...", len(A))

                provider.terminate_nodes(A)

                cli_logger.print("Requested {} nodes to shut down.",
                                 cf.bold(len(A)),
                                 _tags=dict(interval="1s"))

                time.sleep(1)  # todo: interval should be a variable
                A = remaining_nodes()
                cli_logger.print("{} nodes remaining after 1 second.",
                                 cf.bold(len(A)))
            cli_logger.success("No nodes remaining.")
    finally:
        provider.cleanup()
Example No. 28
0
File: updater.py Project: aeli0/ray
    def do_update(self):
        self.provider.set_node_tags(
            self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
        cli_logger.labeled_value("New status", STATUS_WAITING_FOR_SSH)

        deadline = time.time() + NODE_START_WAIT_S
        self.wait_ready(deadline)

        node_tags = self.provider.node_tags(self.node_id)
        logger.debug("Node tags: {}".format(str(node_tags)))

        # runtime_hash will only change whenever the user restarts
        # or updates their cluster with `get_or_create_head_node`
        if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash and (
                self.file_mounts_contents_hash is None
                or node_tags.get(TAG_RAY_FILE_MOUNTS_CONTENTS) ==
                self.file_mounts_contents_hash):
            # todo: we lie in the confirmation message since
            # full setup might be cancelled here
            cli_logger.print(
                "Configuration already up to date, "
                "skipping file mounts, initalization and setup commands.")
            cli_logger.old_info(logger,
                                "{}{} already up-to-date, skip to ray start",
                                self.log_prefix, self.node_id)

        else:
            cli_logger.print(
                "Updating cluster configuration.",
                _tags=dict(hash=self.runtime_hash))

            self.provider.set_node_tags(
                self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES})
            cli_logger.labeled_value("New status", STATUS_SYNCING_FILES)
            self.sync_file_mounts(self.rsync_up)

            # Only run setup commands if runtime_hash has changed because
            # we don't want to run setup_commands every time the head node
            # file_mounts folders have changed.
            if node_tags.get(TAG_RAY_RUNTIME_CONFIG) != self.runtime_hash:
                # Run init commands
                self.provider.set_node_tags(
                    self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP})
                cli_logger.labeled_value("New status", STATUS_SETTING_UP)

                if self.initialization_commands:
                    with cli_logger.group(
                            "Running initialization commands",
                            _numbered=("[]", 4,
                                       6)):  # todo: fix command numbering
                        with LogTimer(
                                self.log_prefix + "Initialization commands",
                                show_status=True):

                            for cmd in self.initialization_commands:
                                self.cmd_runner.run(
                                    cmd,
                                    ssh_options_override=SSHOptions(
                                        self.auth_config.get(
                                            "ssh_private_key")))
                else:
                    cli_logger.print(
                        "No initialization commands to run.",
                        _numbered=("[]", 4, 6))

                if self.setup_commands:
                    with cli_logger.group(
                            "Running setup commands",
                            _numbered=("[]", 5,
                                       6)):  # todo: fix command numbering
                        with LogTimer(
                                self.log_prefix + "Setup commands",
                                show_status=True):

                            total = len(self.setup_commands)
                            for i, cmd in enumerate(self.setup_commands):
                                if cli_logger.verbosity == 0:
                                    cmd_to_print = cf.bold(cmd[:30]) + "..."
                                else:
                                    cmd_to_print = cf.bold(cmd)

                                cli_logger.print(
                                    "{}",
                                    cmd_to_print,
                                    _numbered=("()", i, total))

                                self.cmd_runner.run(cmd)
                else:
                    cli_logger.print(
                        "No setup commands to run.", _numbered=("[]", 5, 6))

        with cli_logger.group(
                "Starting the Ray runtime", _numbered=("[]", 6, 6)):
            with LogTimer(
                    self.log_prefix + "Ray start commands", show_status=True):
                for cmd in self.ray_start_commands:
                    self.cmd_runner.run(cmd)
Example No. 29
0
    def run(self):
        cli_logger.old_info(logger, "{}Updating to {}", self.log_prefix,
                            self.runtime_hash)

        if cmd_output_util.does_allow_interactive(
        ) and cmd_output_util.is_output_redirected():
            # this is most probably a bug since the user has no control
            # over these settings
            msg = ("Output was redirected for an interactive command. "
                   "Either do not pass `--redirect-command-output` "
                   "or also pass in `--use-normal-shells`.")
            cli_logger.abort(msg)
            raise click.ClickException(msg)

        try:
            with LogTimer(self.log_prefix +
                          "Applied config {}".format(self.runtime_hash)):
                self.do_update()
        except Exception as e:
            error_str = str(e)
            if hasattr(e, "cmd"):
                error_str = "(Exit Status {}) {}".format(
                    e.returncode, " ".join(e.cmd))

            self.provider.set_node_tags(
                self.node_id, {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
            cli_logger.error("New status: {}", cf.bold(STATUS_UPDATE_FAILED))

            cli_logger.old_error(logger, "{}Error executing: {}\n",
                                 self.log_prefix, error_str)

            cli_logger.error("!!!")
            if hasattr(e, "cmd"):
                cli_logger.error(
                    "Setup command `{}` failed with exit code {}. stderr:",
                    cf.bold(e.cmd), e.returncode)
            else:
                cli_logger.verbose_error("{}", str(vars(e)))
                # todo: handle this better somehow?
                cli_logger.error("{}", str(e))
            # todo: print stderr here
            cli_logger.error("!!!")
            cli_logger.newline()

            if isinstance(e, click.ClickException):
                # todo: why do we ignore this here
                return
            raise

        tags_to_set = {
            TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
            TAG_RAY_RUNTIME_CONFIG: self.runtime_hash,
        }
        if self.file_mounts_contents_hash is not None:
            tags_to_set[
                TAG_RAY_FILE_MOUNTS_CONTENTS] = self.file_mounts_contents_hash

        self.provider.set_node_tags(self.node_id, tags_to_set)
        cli_logger.labeled_value("New status", STATUS_UP_TO_DATE)

        self.exitcode = 0
Example No. 30
0
    def create_node(self, node_config, tags, count):
        tags = copy.deepcopy(tags)
        # Try to reuse previously stopped nodes with compatible configs
        if self.cache_stopped_nodes:
            # TODO(ekl) this is breaking the abstraction boundary a little by
            # peeking into the tag set.
            filters = [
                {
                    "Name": "instance-state-name",
                    "Values": ["stopped", "stopping"],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_CLUSTER_NAME),
                    "Values": [self.cluster_name],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_NODE_KIND),
                    "Values": [tags[TAG_RAY_NODE_KIND]],
                },
                {
                    "Name": "tag:{}".format(TAG_RAY_LAUNCH_CONFIG),
                    "Values": [tags[TAG_RAY_LAUNCH_CONFIG]],
                },
            ]
            # This tag may not always be present.
            if TAG_RAY_USER_NODE_TYPE in tags:
                filters.append({
                    "Name": "tag:{}".format(TAG_RAY_USER_NODE_TYPE),
                    "Values": [tags[TAG_RAY_USER_NODE_TYPE]],
                })

            reuse_nodes = list(
                self.ec2.instances.filter(Filters=filters))[:count]
            reuse_node_ids = [n.id for n in reuse_nodes]
            if reuse_nodes:
                cli_logger.print(
                    # todo: handle plural vs singular?
                    "Reusing nodes {}. "
                    "To disable reuse, set `cache_stopped_nodes: False` "
                    "under `provider` in the cluster configuration.",
                    cli_logger.render_list(reuse_node_ids))
                cli_logger.old_info(
                    logger, "AWSNodeProvider: reusing instances {}. "
                    "To disable reuse, set "
                    "'cache_stopped_nodes: False' in the provider "
                    "config.", reuse_node_ids)

                # todo: timed?
                with cli_logger.group("Stopping instances to reuse"):
                    for node in reuse_nodes:
                        self.tag_cache[node.id] = from_aws_format(
                            {x["Key"]: x["Value"]
                             for x in node.tags})
                        if node.state["Name"] == "stopping":
                            cli_logger.print("Waiting for instance {} to stop",
                                             node.id)
                            cli_logger.old_info(
                                logger,
                                "AWSNodeProvider: waiting for instance "
                                "{} to fully stop...", node.id)
                            node.wait_until_stopped()

                self.ec2.meta.client.start_instances(
                    InstanceIds=reuse_node_ids)
                for node_id in reuse_node_ids:
                    self.set_node_tags(node_id, tags)
                count -= len(reuse_node_ids)

        if count:
            self._create_node(node_config, tags, count)