예제 #1
0
 def testValidateDefaultConfig(self):
     """Every config in ``CONFIG_PATHS`` must validate after default fill-out.

     Each path is parsed with ``yaml.safe_load``, completed with
     ``fillout_defaults`` and passed to ``validate_config``; any exception
     raised by validation fails the test.
     """
     for config_path in CONFIG_PATHS:
         with open(config_path) as f:
             config = yaml.safe_load(f)
         config = fillout_defaults(config)
         try:
             validate_config(config)
         except Exception as e:
             # Name the offending file and the underlying error: in a loop
             # over several configs a fixed message gives no clue which one
             # broke or why.
             self.fail("Config {} did not pass validation test: {}".format(
                 config_path, e))
예제 #2
0
 def testValidateDefaultConfig(self):
     """A minimal AWS provider section validates once defaults are filled in.

     Only ``provider`` (type, region, availability zone) is supplied; the
     rest comes from ``fillout_defaults``. A ``ValidationError`` fails the
     test.
     """
     minimal_config = {
         "provider": {
             "type": "aws",
             "region": "us-east-1",
             "availability_zone": "us-east-1a",
         },
     }
     filled_config = fillout_defaults(minimal_config)
     try:
         validate_config(filled_config)
     except ValidationError:
         self.fail("Default config did not pass validation test!")
예제 #3
0
 def testValidateNetworkConfig(self):
     """The upstream AWS ``example-full.yaml`` must validate after fill-out.

     Downloads the file from the ray-project GitHub ``master`` branch with a
     5 second timeout, so this test needs network access.
     """
     web_yaml = "https://raw.githubusercontent.com/ray-project/ray/" \
         "master/python/ray/autoscaler/aws/example-full.yaml"
     # Close the HTTP response deterministically instead of leaking it.
     with urllib.request.urlopen(web_yaml, timeout=5) as response:
         content = response.read()
     # yaml.safe_load accepts bytes directly; no temp-file round-trip needed.
     config = yaml.safe_load(content)
     config = fillout_defaults(config)
     try:
         validate_config(config)
     except Exception as e:
         # Surface the real validation error rather than hiding it.
         self.fail("Config did not pass validation test: {}".format(e))
예제 #4
0
파일: commands.py 프로젝트: aniryou/ray
def teardown_cluster(config_file, yes, workers_only, override_cluster_name,
                     keep_min_workers):
    """Destroys all nodes of a Ray cluster described by a config json.

    Args:
        config_file: Path to the cluster config, parsed with yaml.safe_load.
        yes: Forwarded to ``confirm`` (presumably skips the interactive
            prompt when true -- confirm against ``confirm``).
        workers_only: If true, only worker nodes are terminated; head nodes
            are left running.
        override_cluster_name: If not None, replaces ``cluster_name`` from
            the config file.
        keep_min_workers: If true, ``min_workers`` (default 0) randomly
            chosen worker nodes are left running.

    The node provider's ``cleanup()`` is always called, even on error.
    """
    with open(config_file) as f:
        config = yaml.safe_load(f)
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name
    config = fillout_defaults(config)
    validate_config(config)

    confirm("This will destroy your cluster", yes)

    provider = get_node_provider(config["provider"], config["cluster_name"])
    try:
        # Choose the surviving workers once, up front. Re-sampling on every
        # poll can pick previously-kept nodes for termination while earlier
        # terminations are still in flight. Clamping the sample size also
        # avoids random.sample's ValueError when fewer than min_workers
        # workers exist.
        workers_to_keep = []
        if keep_min_workers:
            min_workers = config.get("min_workers", 0)
            logger.info("teardown_cluster: "
                        "Keeping {} nodes...".format(min_workers))
            current_workers = provider.non_terminated_nodes({
                TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
            })
            workers_to_keep = random.sample(
                current_workers, min(min_workers, len(current_workers)))

        def remaining_nodes():
            """Return the node ids that still have to be terminated."""
            workers = [
                node for node in provider.non_terminated_nodes({
                    TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
                }) if node not in workers_to_keep
            ]

            if workers_only:
                return workers

            head = provider.non_terminated_nodes({
                TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD
            })

            return head + workers

        # Loop here to check that both the head and worker nodes are
        # actually gone; termination may not take effect immediately.
        nodes = remaining_nodes()
        with LogTimer("teardown_cluster: done."):
            while nodes:
                logger.info("teardown_cluster: "
                            "Shutting down {} nodes...".format(len(nodes)))
                provider.terminate_nodes(nodes)
                time.sleep(1)
                nodes = remaining_nodes()
    finally:
        provider.cleanup()
예제 #5
0
 def testReportsConfigFailures(self):
     """Nodes end up tagged STATUS_UPDATE_FAILED when setup commands fail.

     The config is filled out as an "external" provider, then switched to
     the mock provider before being written. The process runner is told to
     fail commands matching "setup_cmd", and failures are not tolerated
     (``max_failures=0``).
     """
     cluster_config = copy.deepcopy(SMALL_CLUSTER)
     cluster_config["provider"]["type"] = "external"
     cluster_config = fillout_defaults(cluster_config)
     # Re-index after fill-out: the provider dict may have been replaced.
     cluster_config["provider"]["type"] = "mock"
     written_path = self.write_config(cluster_config)
     self.provider = MockProvider()
     failing_runner = MockProcessRunner(fail_cmds=["setup_cmd"])
     scaler = StandardAutoscaler(
         written_path,
         LoadMetrics(),
         max_failures=0,
         process_runner=failing_runner,
         update_interval_s=0)
     for _ in range(2):
         scaler.update()
     self.waitForNodes(2)
     self.provider.finish_starting_nodes()
     scaler.update()
     failed_filter = {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED}
     self.waitForNodes(2, tag_filters=failed_filter)
예제 #6
0
파일: commands.py 프로젝트: aniryou/ray
def _bootstrap_config(config):
    """Fill in defaults, validate, and resolve ``config`` via its provider.

    The resolved config is cached as JSON in the system temp directory,
    keyed by the SHA-1 of the defaults-filled config. A cache hit returns
    the cached value directly, skipping validation and provider bootstrap.
    An unparsable cache file is ignored and regenerated.

    Raises:
        NotImplementedError: if ``config["provider"]["type"]`` has no
            importer in ``NODE_PROVIDERS``.
    """
    config = fillout_defaults(config)

    hasher = hashlib.sha1()
    hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
    cache_key = os.path.join(tempfile.gettempdir(),
                             "ray-config-{}".format(hasher.hexdigest()))
    if os.path.exists(cache_key):
        try:
            with open(cache_key) as f:
                cached_config = json.load(f)
        except ValueError:
            # A truncated or corrupt cache file (e.g. from an interrupted
            # write) would otherwise break every future call; rebuild it.
            logger.warning(
                "Ignoring unreadable cached config at {}".format(cache_key))
        else:
            logger.info("Using cached config at {}".format(cache_key))
            return cached_config
    validate_config(config)

    importer = NODE_PROVIDERS.get(config["provider"]["type"])
    if not importer:
        raise NotImplementedError("Unsupported provider {}".format(
            config["provider"]))

    bootstrap_config, _ = importer()
    resolved_config = bootstrap_config(config)
    with open(cache_key, "w") as f:
        f.write(json.dumps(resolved_config))
    return resolved_config