예제 #1
0
def _bootstrap_config(config: Dict[str, Any],
                      no_config_cache: bool = False) -> Dict[str, Any]:
    """Resolve a cluster config through its node provider, using a disk cache.

    The config is first passed through ``prepare_config``. A SHA-1 digest of
    its JSON serialization (with sorted keys) names a cache file in the
    system temp directory. If that file exists, caching is enabled, and the
    stored ``_version`` equals ``CONFIG_CACHE_VERSION``, the cached resolved
    config is returned directly. Otherwise the config is validated, the
    provider's ``bootstrap_config`` is run, and (when caching is enabled) the
    result is written back to the cache file.

    Args:
        config: The cluster config to resolve.
        no_config_cache: When True, the cache file is neither read nor
            written.

    Returns:
        The resolved config, either loaded from the cache or freshly
        produced by the provider's ``bootstrap_config``.

    Raises:
        NotImplementedError: If ``NODE_PROVIDERS`` has no (truthy) importer
            for ``config["provider"]["type"]``.
    """
    config = prepare_config(config)

    # The cache file name is derived from the prepared config's content, so
    # any change to the config maps to a different cache file.
    hasher = hashlib.sha1()
    hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
    cache_key = os.path.join(tempfile.gettempdir(),
                             "ray-config-{}".format(hasher.hexdigest()))

    if os.path.exists(cache_key) and not no_config_cache:
        cli_logger.old_info(logger, "Using cached config at {}", cache_key)

        # Read through a context manager so the file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(cache_key) as cache_file:
            config_cache = json.load(cache_file)

        if config_cache.get("_version", -1) == CONFIG_CACHE_VERSION:
            # todo: is it fine to re-resolve? afaik it should be.
            # we can have migrations otherwise or something
            # but this seems overcomplicated given that resolving is
            # relatively cheap
            try_reload_log_state(config_cache["config"]["provider"],
                                 config_cache.get("provider_log_info"))
            cli_logger.verbose("Loaded cached config from " + cf.bold("{}"),
                               cache_key)

            return config_cache["config"]
        else:
            cli_logger.warning(
                "Found cached cluster config "
                "but the version " + cf.bold("{}") + " "
                "(expected " + cf.bold("{}") + ") does not match.\n"
                "This is normal if cluster launcher was updated.\n"
                "Config will be re-resolved.",
                config_cache.get("_version", "none"), CONFIG_CACHE_VERSION)

    # Cache miss, cache disabled, or cache version mismatch: resolve afresh.
    validate_config(config)

    importer = NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        raise NotImplementedError("Unsupported provider {}".format(
            config["provider"]))

    provider_cls = importer(config["provider"])

    with cli_logger.timed(  # todo: better message
            "Bootstraping {} config",
            PROVIDER_PRETTY_NAMES.get(config["provider"]["type"])):
        resolved_config = provider_cls.bootstrap_config(config)

    if not no_config_cache:
        # Build and serialize the payload *before* opening the file: opening
        # with "w" truncates it, so a failure while collecting the log state
        # or serializing would otherwise leave an empty cache file that
        # breaks the next cached read.
        serialized_cache = json.dumps({
            "_version": CONFIG_CACHE_VERSION,
            "provider_log_info": try_get_log_state(config["provider"]),
            "config": resolved_config
        })
        with open(cache_key, "w") as f:
            f.write(serialized_cache)
    return resolved_config
예제 #2
0
def log_to_cli(config):
    """Print a summary of the AWS settings in ``config`` through cli_logger.

    Inside a "<provider name> config" group this logs the head node's IAM
    instance profile, then the key pair, subnets, security groups and AMI
    read from ``config["head_node"]`` and ``config["worker_nodes"]``. A
    setting whose head and worker values are equal is printed once as
    "(head & workers)"; otherwise head and workers are printed separately.

    Each value is tagged using the module-level ``_log_info`` mapping
    (presumably recording where each setting came from, e.g. "default" --
    confirm against the code that populates it).

    Args:
        config: Cluster config with ``head_node`` and ``worker_nodes``
            sections containing ``KeyName``, ``SubnetIds``,
            ``SecurityGroupIds`` and ``ImageId``, plus
            ``head_node["IamInstanceProfile"]["Arn"]``.
    """
    provider_name = PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(provider_name is not None,
                        "Could not find a pretty name for the AWS provider.")

    with cli_logger.group("{} config", provider_name):

        def same_everywhere(key):
            """Return whether head and worker nodes share the value of key."""
            return config["head_node"][key] == config["worker_nodes"][key]

        def print_info(resource_string,
                       key,
                       head_src_key,
                       workers_src_key,
                       allowed_tags=("default",),
                       list_value=False):
            """Log one setting for head and workers, tagged by its source.

            Args:
                resource_string: Human-readable label for the setting.
                key: Key looked up in both node sections of ``config``.
                head_src_key: ``_log_info`` key for the head value's source.
                workers_src_key: ``_log_info`` key for the workers' source.
                allowed_tags: Source names that are shown as tags; other
                    sources produce no tag. An immutable tuple is used as
                    the default so it cannot be shared-and-mutated across
                    calls.
                list_value: When True, values are rendered with
                    ``cli_logger.render_list`` before printing.
            """
            head_src = _log_info[head_src_key]
            workers_src = _log_info[workers_src_key]

            # Only sources listed in allowed_tags are surfaced as tags.
            head_tags = {head_src: True} if head_src in allowed_tags else {}
            workers_tags = ({
                workers_src: True
            } if workers_src in allowed_tags else {})

            head_value_str = config["head_node"][key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if same_everywhere(key):
                cli_logger.labeled_value(  # todo: handle plural vs singular?
                    resource_string + " (head & workers)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
            else:
                workers_value_str = config["worker_nodes"][key]
                if list_value:
                    workers_value_str = cli_logger.render_list(
                        workers_value_str)

                cli_logger.labeled_value(resource_string + " (head)",
                                         "{}",
                                         head_value_str,
                                         _tags=head_tags)
                cli_logger.labeled_value(resource_string + " (workers)",
                                         "{}",
                                         workers_value_str,
                                         _tags=workers_tags)

        # The IAM profile is only reported for the head node.
        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        cli_logger.labeled_value(
            "IAM Profile",
            "{}",
            _arn_to_name(config["head_node"]["IamInstanceProfile"]["Arn"]),
            _tags=tags)

        # The key pair has a single source entry used for head and workers.
        print_info("EC2 Key pair", "KeyName", "keypair_src", "keypair_src")
        print_info("VPC Subnets",
                   "SubnetIds",
                   "head_subnet_src",
                   "workers_subnet_src",
                   list_value=True)
        print_info("EC2 Security groups",
                   "SecurityGroupIds",
                   "head_security_group_src",
                   "workers_security_group_src",
                   list_value=True)
        print_info("EC2 AMI",
                   "ImageId",
                   "head_ami_src",
                   "workers_ami_src",
                   allowed_tags=["dlami"])

    cli_logger.newline()