Example 1
    def reset(self, errors_fatal=False):
        sync_continuously = False
        if hasattr(self, "config"):
            sync_continuously = self.config.get(
                "file_mounts_sync_continuously", False)
        try:
            with open(self.config_path) as f:
                new_config = yaml.safe_load(f.read())
            if new_config != getattr(self, "config", None):
                try:
                    validate_config(new_config)
                except Exception as e:
                    self.prom_metrics.config_validation_exceptions.inc()
                    logger.debug(
                        "Cluster config validation failed. The version of "
                        "the ray CLI you launched this cluster with may "
                        "be higher than the version of ray being run on "
                        "the cluster. Some new features may not be "
                        "available until you upgrade ray on your cluster.",
                        exc_info=e)
            (new_runtime_hash,
             new_file_mounts_contents_hash) = hash_runtime_conf(
                 new_config["file_mounts"],
                 new_config["cluster_synced_files"],
                 [
                     new_config["worker_setup_commands"],
                     new_config["worker_start_ray_commands"],
                 ],
                 generate_file_mounts_contents_hash=sync_continuously,
             )
            self.config = new_config
            self.runtime_hash = new_runtime_hash
            self.file_mounts_contents_hash = new_file_mounts_contents_hash
            if not self.provider:
                self.provider = _get_node_provider(self.config["provider"],
                                                   self.config["cluster_name"])

            # If using the LocalNodeProvider, make sure the head node is marked
            # non-terminated.
            if isinstance(self.provider, LocalNodeProvider):
                record_local_head_state_if_needed(self.provider)

            self.available_node_types = self.config["available_node_types"]
            upscaling_speed = self.config.get("upscaling_speed")
            aggressive = self.config.get("autoscaling_mode") == "aggressive"
            target_utilization_fraction = self.config.get(
                "target_utilization_fraction")
            if upscaling_speed:
                upscaling_speed = float(upscaling_speed)
            # TODO(ameer): consider adding (if users ask) an option of
            # initial_upscaling_num_workers.
            elif aggressive:
                upscaling_speed = 99999
                logger.warning(
                    "Legacy aggressive autoscaling mode "
                    "detected. Replacing it by setting upscaling_speed to "
                    "99999.")
            elif target_utilization_fraction:
                upscaling_speed = (
                    1 / max(target_utilization_fraction, 0.001) - 1)
                logger.warning(
                    "Legacy target_utilization_fraction config "
                    "detected. Replacing it by setting upscaling_speed to " +
                    "1 / target_utilization_fraction - 1.")
            else:
                upscaling_speed = 1.0
            if self.resource_demand_scheduler:
                # For legacy yamls, the node types are autofilled internally;
                # recreating the scheduler would discard those inferred node
                # resources, so we update the existing instance in place.
                self.resource_demand_scheduler.reset_config(
                    self.provider, self.available_node_types,
                    self.config["max_workers"], self.config["head_node_type"],
                    upscaling_speed)
            else:
                self.resource_demand_scheduler = ResourceDemandScheduler(
                    self.provider, self.available_node_types,
                    self.config["max_workers"], self.config["head_node_type"],
                    upscaling_speed)

        except Exception as e:
            self.prom_metrics.reset_exceptions.inc()
            if errors_fatal:
                raise e
            else:
                logger.exception("StandardAutoscaler: "
                                 "Error parsing config.")
Example 2
    def testClusterStateInit(self):
        """Check ClusterState __init__ func generates correct state file.

        Test the general use case and if num_workers increase/decrease.
        """
        self._monkeypatch.setenv("RAY_TMPDIR", self._tmpdir)
        # Ensure that a new cluster can start up if RAY_TMPDIR doesn't exist yet.
        assert not os.path.exists(get_ray_temp_dir())
        # Use a random head_ip so that the state file is regenerated each time
        # this test is run. (Otherwise the test will fail spuriously when run a
        # second time.)
        head_ip = ".".join(str(random.randint(0, 255)) for _ in range(4))
        cluster_config = {
            "cluster_name": "random_name",
            "min_workers": 0,
            "max_workers": 0,
            "provider": {
                "type": "local",
                "head_ip": head_ip,
                "worker_ips": ["0.0.0.0:1"],
                "external_head_ip": "0.0.0.0.3",
            },
        }
        provider_config = cluster_config["provider"]
        node_provider = _get_node_provider(
            provider_config, cluster_config["cluster_name"], use_cache=False
        )
        assert os.path.exists(get_ray_temp_dir())
        assert node_provider.external_ip(head_ip) == "0.0.0.0.3"
        assert isinstance(node_provider, LocalNodeProvider)
        expected_workers = {}
        expected_workers[provider_config["head_ip"]] = {
            "tags": {TAG_RAY_NODE_KIND: NODE_KIND_HEAD},
            "state": "terminated",
            "external_ip": "0.0.0.0.3",
        }
        expected_workers[provider_config["worker_ips"][0]] = {
            "tags": {TAG_RAY_NODE_KIND: NODE_KIND_WORKER},
            "state": "terminated",
        }

        state_save_path = local_config.get_state_path(cluster_config["cluster_name"])

        assert os.path.exists(state_save_path)
        with open(state_save_path) as f:
            workers = json.load(f)
        assert workers == expected_workers

        # Test removing workers updates the cluster state.
        del expected_workers[provider_config["worker_ips"][0]]
        removed_ip = provider_config["worker_ips"].pop()
        node_provider = _get_node_provider(
            provider_config, cluster_config["cluster_name"], use_cache=False
        )
        with open(state_save_path) as f:
            workers = json.load(f)
        assert workers == expected_workers

        # Test adding back workers updates the cluster state.
        expected_workers[removed_ip] = {
            "tags": {TAG_RAY_NODE_KIND: NODE_KIND_WORKER},
            "state": "terminated",
        }
        provider_config["worker_ips"].append(removed_ip)
        node_provider = _get_node_provider(
            provider_config, cluster_config["cluster_name"], use_cache=False
        )
        with open(state_save_path) as f:
            workers = json.load(f)
        assert workers == expected_workers

        # Test record_local_head_state_if_needed
        head_ip = cluster_config["provider"]["head_ip"]
        cluster_name = cluster_config["cluster_name"]
        node_provider = _get_node_provider(
            provider_config, cluster_config["cluster_name"], use_cache=False
        )
        assert head_ip not in node_provider.non_terminated_nodes({})
        record_local_head_state_if_needed(node_provider)
        assert head_ip in node_provider.non_terminated_nodes({})
        expected_head_tags = {
            TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
            TAG_RAY_USER_NODE_TYPE: local_config.LOCAL_CLUSTER_NODE_TYPE,
            TAG_RAY_NODE_NAME: "ray-{}-head".format(cluster_name),
            TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
        }
        assert node_provider.node_tags(head_ip) == expected_head_tags
        # Repeat and verify nothing has changed.
        record_local_head_state_if_needed(node_provider)
        assert head_ip in node_provider.non_terminated_nodes({})
        assert node_provider.node_tags(head_ip) == expected_head_tags
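
For reference, the state file asserted against above is plain JSON keyed by node IP, with each entry holding the node's tags, lifecycle state, and (for the head node) an optional external IP. A minimal sketch of writing and reading that shape, using literal strings as stand-ins for the TAG_RAY_NODE_KIND / NODE_KIND_* constants imported by the test, could look like:

import json
import tempfile

# Literal stand-ins for the TAG_RAY_NODE_KIND / NODE_KIND_* constants above.
expected_workers = {
    "192.0.2.10": {  # head_ip
        "tags": {"ray-node-kind": "head"},
        "state": "terminated",
        "external_ip": "0.0.0.0.3",
    },
    "0.0.0.0:1": {  # worker ip
        "tags": {"ray-node-kind": "worker"},
        "state": "terminated",
    },
}

# Write the state file, then read it back the way the test does.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(expected_workers, f)
    state_save_path = f.name

with open(state_save_path) as f:
    assert json.load(f) == expected_workers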