示例#1
0
    def testDelayedLaunch(self):
        """Launches wait on the provider's ready flag; a smaller config later
        shrinks the cluster."""
        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        scaler = StandardAutoscaler(
            config_path,
            LoadMetrics(),
            max_failures=0,
            max_concurrent_launches=5,
            max_launch_batch=5,
            update_interval_s=0)

        def live_count():
            return len(self.provider.non_terminated_nodes({}))

        assert live_count() == 0

        # Hold node creation: update() starts the launches, but they block
        # until the provider's flag is set again.
        self.provider.ready_to_create.clear()
        scaler.update()
        assert scaler.num_launches_pending.value == 2
        assert live_count() == 0

        # Release the flag; the pending launches should now complete.
        self.provider.ready_to_create.set()
        self.waitForNodes(2)
        assert scaler.num_launches_pending.value == 0

        # Shrink the cluster by lowering max_workers in the config.
        shrunk = SMALL_CLUSTER.copy()
        shrunk["max_workers"] = 1
        self.write_config(shrunk)
        scaler.update()
        assert live_count() == 1
示例#2
0
 def testMaxFailures(self):
     """With a throwing provider, two failed updates are tolerated and the
     third one raises."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     self.provider.throw = True
     scaler = StandardAutoscaler(
         config_path, LoadMetrics(), update_interval_s=0, max_failures=2)
     for _ in range(2):
         scaler.update()
     self.assertRaises(Exception, scaler.update)
示例#3
0
 def testScaleUp(self):
     """An empty cluster scales up to two nodes and stays there."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(
         config_path, LoadMetrics(), update_interval_s=0, max_failures=0)
     assert len(self.provider.non_terminated_nodes({})) == 0
     # The repeat update must not launch anything beyond the first wave.
     for _ in range(2):
         scaler.update()
         self.waitForNodes(2)
示例#4
0
 def testScaleUp(self):
     """The first update creates two nodes; a repeat update adds none."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(
         config_path, LoadMetrics(), update_interval_s=0, max_failures=0)
     self.assertEqual(len(self.provider.nodes({})), 0)
     for _ in range(2):
         scaler.update()
         self.assertEqual(len(self.provider.nodes({})), 2)
示例#5
0
    def testDontScaleBelowTarget(self):
        """Worker count follows head-node load with a 0.5 target fraction."""
        config = SMALL_CLUSTER.copy()
        config["min_workers"] = 0
        config["max_workers"] = 2
        config["target_utilization_fraction"] = 0.5
        config_path = self.write_config(config)
        self.provider = MockProvider()
        metrics = LoadMetrics()
        scaler = StandardAutoscaler(
            config_path, metrics, update_interval_s=0, max_failures=0)

        def live_count():
            return len(self.provider.non_terminated_nodes({}))

        assert live_count() == 0
        scaler.update()
        assert scaler.num_launches_pending.value == 0
        assert live_count() == 0

        # Scales up as nodes are reported as used:
        # 1.0 nodes used => target nodes = 2 => target workers = 1
        head_ip = services.get_node_ip_address()
        metrics.update(head_ip, {"CPU": 2}, {"CPU": 0})  # head
        scaler.update()
        self.waitForNodes(1)

        # The new worker is idle and never used, but it must be kept
        # because the target is still 2 nodes.
        metrics.update("172.0.0.0", {"CPU": 0}, {"CPU": 0})
        metrics.last_used_time_by_ip["172.0.0.0"] = 0
        scaler.update()
        assert live_count() == 1

        # Reduce load on head => target nodes = 1 => target workers = 0
        metrics.update(head_ip, {"CPU": 2}, {"CPU": 1})
        scaler.update()
        assert live_count() == 0
示例#6
0
    def testDynamicScaling(self):
        """Edits to min/max_workers in the config are applied by update().

        Growing from 1 to 10 workers takes two updates (6 nodes, then 10),
        presumably because max_concurrent_launches=5 caps each wave.
        """
        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(
            config_path, LoadMetrics(), max_concurrent_launches=5,
            max_failures=0, update_interval_s=0)
        self.assertEqual(len(self.provider.nodes({})), 0)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 2)

        # Update the config to reduce the cluster size
        new_config = SMALL_CLUSTER.copy()
        new_config["max_workers"] = 1
        self.write_config(new_config)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 1)

        # Update the config to increase the cluster size
        new_config["min_workers"] = 10
        new_config["max_workers"] = 10
        self.write_config(new_config)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 6)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 10)
示例#7
0
 def testUpdateThrottling(self):
     """With update_interval_s=10, an immediate second update() does not
     apply a config change."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(
         config_path,
         LoadMetrics(),
         update_interval_s=10,
         max_failures=0,
         max_concurrent_launches=5)

     def node_count():
         return len(self.provider.nodes({}))

     scaler.update()
     self.assertEqual(node_count(), 2)
     # Ask for a smaller cluster; the throttled update must not act on it.
     smaller = SMALL_CLUSTER.copy()
     smaller["max_workers"] = 1
     self.write_config(smaller)
     scaler.update()
     self.assertEqual(node_count(), 2)  # not updated yet
示例#8
0
 def testLaunchNewNodeOnOutOfBandTerminate(self):
     """Nodes terminated outside the autoscaler are replaced."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(
         config_path, LoadMetrics(), update_interval_s=0, max_failures=0)

     def node_count():
         return len(self.provider.nodes({}))

     for _ in range(2):
         scaler.update()
     self.assertEqual(node_count(), 2)
     # Terminate every node behind the autoscaler's back.
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "terminated"
     self.assertEqual(node_count(), 0)
     scaler.update()
     self.assertEqual(node_count(), 2)
示例#9
0
    def testDelayedLaunchWithFailure(self):
        """Blocked, failed and retried launch waves are all accounted for.

        Launches go out in batches of max_launch_batch=5, with at most
        max_concurrent_launches=8 pending. A blocked first wave of 5 and a
        completed second wave of 3 are followed by the first wave failing.
        Later updates then retry until 10 workers exist.
        """
        config = SMALL_CLUSTER.copy()
        config["min_workers"] = 10
        config["max_workers"] = 10
        config_path = self.write_config(config)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(
            config_path,
            LoadMetrics(),
            max_launch_batch=5,
            max_concurrent_launches=8,
            max_failures=0,
            update_interval_s=0)
        assert len(self.provider.non_terminated_nodes({})) == 0

        # update() should launch a wave of 5 nodes (max_launch_batch)
        # Force this first wave to block.
        rtc1 = self.provider.ready_to_create
        rtc1.clear()
        autoscaler.update()
        # Synchronization: wait for launchy thread to be blocked on rtc1.
        # This peeks at private threading.Event internals, whose attribute
        # names differ between Python 2 and Python 3.
        if hasattr(rtc1, '_cond'):  # Python 3
            waiters = rtc1._cond._waiters
        else:  # Python 2.7
            waiters = rtc1._Event__cond._Condition__waiters
        self.waitFor(lambda: len(waiters) == 1)
        assert autoscaler.num_launches_pending.value == 5
        assert len(self.provider.non_terminated_nodes({})) == 0

        # Call update() to launch a second wave of 3 nodes,
        # as 5 + 3 = 8 = max_concurrent_launches.
        # Make this wave complete immediately.
        rtc2 = threading.Event()
        self.provider.ready_to_create = rtc2
        rtc2.set()
        autoscaler.update()
        self.waitForNodes(3)
        assert autoscaler.num_launches_pending.value == 5

        # The first wave of 5 will now tragically fail
        self.provider.fail_creates = True
        rtc1.set()
        self.waitFor(lambda: autoscaler.num_launches_pending.value == 0)
        assert len(self.provider.non_terminated_nodes({})) == 3

        # Retry the first wave, allowing it to succeed this time
        self.provider.fail_creates = False
        autoscaler.update()
        self.waitForNodes(8)
        assert autoscaler.num_launches_pending.value == 0

        # Final wave of 2 nodes
        autoscaler.update()
        self.waitForNodes(10)
        assert autoscaler.num_launches_pending.value == 0
示例#10
0
 def testConfiguresNewNodes(self):
     """Freshly launched nodes end up tagged "up-to-date"."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     runner = MockProcessRunner()
     scaler = StandardAutoscaler(
         config_path,
         LoadMetrics(),
         process_runner=runner,
         max_failures=0,
         update_interval_s=0)
     for _ in range(2):
         scaler.update()
     self.waitForNodes(2)
     # Mark the mock instances as running before the next update.
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     scaler.update()
     self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
示例#11
0
    def testIgnoresCorruptedConfig(self):
        """An unparsable config is ignored; a later valid config is applied."""
        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(
            config_path,
            LoadMetrics(),
            max_launch_batch=10,
            max_concurrent_launches=10,
            max_failures=0,
            update_interval_s=0)
        autoscaler.update()
        self.waitForNodes(2)

        # Write a corrupted config; repeated updates should leave the
        # cluster unchanged.
        self.write_config("asdf")
        for _ in range(10):
            autoscaler.update()
        # Short pause before checking that nothing new was launched.
        time.sleep(0.1)
        assert autoscaler.num_launches_pending.value == 0
        assert len(self.provider.non_terminated_nodes({})) == 2

        # Now write a good config again
        new_config = SMALL_CLUSTER.copy()
        new_config["min_workers"] = 10
        new_config["max_workers"] = 10
        self.write_config(new_config)
        autoscaler.update()
        self.waitForNodes(10)
示例#12
0
    def testRecoverUnhealthyWorkers(self):
        """A worker whose heartbeat goes stale triggers new runner commands."""
        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        runner = MockProcessRunner()
        metrics = LoadMetrics()
        scaler = StandardAutoscaler(
            config_path,
            metrics,
            node_updater_cls=NodeUpdaterThread,
            verbose_updates=True,
            process_runner=runner,
            max_failures=0,
            update_interval_s=0)
        scaler.update()
        self.waitForNodes(2)
        for mock_node in self.provider.mock_nodes.values():
            mock_node.state = "running"
        scaler.update()
        self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})

        # Zero one worker's heartbeat time so it looks unhealthy, then wait
        # for the runner to receive further commands.
        metrics.last_heartbeat_time_by_ip["172.0.0.0"] = 0
        calls_before = len(runner.calls)
        scaler.update()
        self.waitFor(lambda: len(runner.calls) > calls_before)
示例#13
0
 def testConfiguresOutdatedNodes(self):
     """Changing worker_setup_commands makes the runner execute new commands
     on already-configured nodes."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     runner = MockProcessRunner()
     scaler = StandardAutoscaler(
         config_path,
         LoadMetrics(),
         node_updater_cls=NodeUpdaterThread,
         verbose_updates=True,
         process_runner=runner,
         max_failures=0,
         update_interval_s=0)
     for _ in range(2):
         scaler.update()
     self.waitForNodes(2)
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     scaler.update()
     self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})

     # Forget earlier commands, then change the worker setup commands.
     runner.calls = []
     changed = SMALL_CLUSTER.copy()
     changed["worker_setup_commands"] = ["cmdX", "cmdY"]
     self.write_config(changed)
     for _ in range(2):
         scaler.update()
     self.waitFor(lambda: len(runner.calls) > 0)
示例#14
0
    def testLaunchConfigChange(self):
        """Changing the worker node config replaces every existing worker.

        The nested ``worker_nodes`` dict is modified, so the config must be
        deep-copied. A shallow ``SMALL_CLUSTER.copy()`` shares that dict and
        would leak ``InstanceType = "updated"`` into SMALL_CLUSTER itself,
        which could affect tests that run afterwards.
        """
        import copy

        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(
            config_path, LoadMetrics(), max_failures=0, update_interval_s=0)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 2)

        # Update the config to change the node type
        new_config = copy.deepcopy(SMALL_CLUSTER)
        new_config["worker_nodes"]["InstanceType"] = "updated"
        self.write_config(new_config)
        existing_nodes = set(self.provider.nodes({}))
        for _ in range(5):
            autoscaler.update()
        new_nodes = set(self.provider.nodes({}))
        # Same size, but none of the original nodes survived.
        self.assertEqual(len(new_nodes), 2)
        self.assertEqual(len(new_nodes.intersection(existing_nodes)), 0)
示例#15
0
    def testLaunchConfigChange(self):
        """Changing the worker node config replaces the existing workers.

        The nested ``worker_nodes`` dict is modified, so the config must be
        deep-copied. A shallow ``SMALL_CLUSTER.copy()`` shares that dict and
        would leak ``InstanceType = "updated"`` into SMALL_CLUSTER itself,
        which could affect tests that run afterwards.
        """
        import copy

        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(
            config_path, LoadMetrics(), max_failures=0, update_interval_s=0)
        autoscaler.update()
        self.waitForNodes(2)

        # Update the config to change the node type
        new_config = copy.deepcopy(SMALL_CLUSTER)
        new_config["worker_nodes"]["InstanceType"] = "updated"
        self.write_config(new_config)
        # Block node creation so that the old workers are gone (0 nodes)
        # before any replacement can be created.
        self.provider.ready_to_create.clear()
        for _ in range(5):
            autoscaler.update()
        self.waitForNodes(0)
        self.provider.ready_to_create.set()
        self.waitForNodes(2)
示例#16
0
 def testReportsConfigFailures(self):
     """When the runner fails "cmd1", both nodes end up tagged UpdateFailed."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     runner = MockProcessRunner(fail_cmds=["cmd1"])
     scaler = StandardAutoscaler(
         config_path,
         LoadMetrics(),
         update_interval_s=0,
         node_updater_cls=NodeUpdaterThread,
         verbose_updates=True,
         process_runner=runner,
         max_failures=0)

     def count_with_status(status):
         return len(self.provider.nodes({TAG_RAY_NODE_STATUS: status}))

     for _ in range(2):
         scaler.update()
     self.assertEqual(len(self.provider.nodes({})), 2)
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     assert count_with_status("Uninitialized") == 2
     scaler.update()
     self.waitFor(lambda: count_with_status("UpdateFailed") == 2)
示例#17
0
    def testTerminateOutdatedNodesGracefully(self):
        """Ten pre-existing workers are trimmed to the 5-worker target.

        After every update the node count must stay within [4, 5].
        """
        config = SMALL_CLUSTER.copy()
        config["min_workers"] = 5
        config["max_workers"] = 5
        config_path = self.write_config(config)
        self.provider = MockProvider()
        self.provider.create_node({}, {TAG_RAY_NODE_TYPE: "Worker"}, 10)
        scaler = StandardAutoscaler(
            config_path, LoadMetrics(), update_interval_s=0, max_failures=0)

        def node_count():
            return len(self.provider.nodes({}))

        self.assertEqual(node_count(), 10)

        # Scale down toward the target without ever going too low.
        for _ in range(10):
            scaler.update()
            self.assertLessEqual(node_count(), 5)
            self.assertGreaterEqual(node_count(), 4)

        # The steady state is exactly the configured size.
        self.assertEqual(node_count(), 5)
示例#18
0
 def testUpdateThrottling(self):
     """Config changes are not applied before update_interval_s elapses."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(
         config_path,
         LoadMetrics(),
         update_interval_s=10,
         max_failures=0,
         max_concurrent_launches=5,
         max_launch_batch=5)
     scaler.update()
     self.waitForNodes(2)
     assert scaler.num_launches_pending.value == 0
     smaller = SMALL_CLUSTER.copy()
     smaller["max_workers"] = 1
     self.write_config(smaller)
     scaler.update()
     # not updated yet
     # note that node termination happens in the main thread, so
     # we do not need to add any delay here before checking
     assert len(self.provider.non_terminated_nodes({})) == 2
     assert scaler.num_launches_pending.value == 0
示例#19
0
    def testSetupCommandsWithNoNodeCaching(self):
        """With cache_stopped=False, a replacement worker is a new node.

        Both the first worker (172.0.0.0) and its replacement (172.0.0.1)
        must receive the full init / setup / worker setup / start sequence.
        """
        config = SMALL_CLUSTER.copy()
        config["min_workers"] = 1
        config["max_workers"] = 1
        config_path = self.write_config(config)
        self.provider = MockProvider(cache_stopped=False)
        runner = MockProcessRunner()
        metrics = LoadMetrics()
        scaler = StandardAutoscaler(
            config_path,
            metrics,
            update_interval_s=0,
            process_runner=runner,
            max_failures=0)
        expected_cmds = (
            "init_cmd", "setup_cmd", "worker_setup_cmd", "start_ray_worker")

        scaler.update()
        self.waitForNodes(1)
        self.provider.finish_starting_nodes()
        scaler.update()
        self.waitForNodes(
            1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
        for cmd in expected_cmds:
            runner.assert_has_call("172.0.0.0", cmd)

        # Terminate the worker. Its replacement gets a new IP and runs every
        # command again, showing that the old node was not reused.
        self.provider.terminate_node(0)
        scaler.update()
        self.waitForNodes(1)
        runner.clear_history()
        self.provider.finish_starting_nodes()
        scaler.update()
        self.waitForNodes(
            1, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
        for cmd in expected_cmds:
            runner.assert_has_call("172.0.0.1", cmd)
示例#20
0
    def testDontScaleBelowTarget(self):
        """Worker count follows head-node load with a 0.5 target fraction."""
        config = SMALL_CLUSTER.copy()
        config["min_workers"] = 0
        config["max_workers"] = 2
        config["target_utilization_fraction"] = 0.5
        config_path = self.write_config(config)
        self.provider = MockProvider()
        metrics = LoadMetrics()
        runner = MockProcessRunner()
        scaler = StandardAutoscaler(
            config_path,
            metrics,
            update_interval_s=0,
            process_runner=runner,
            max_failures=0)

        def live_count():
            return len(self.provider.non_terminated_nodes({}))

        assert live_count() == 0
        scaler.update()
        assert scaler.pending_launches.value == 0
        assert live_count() == 0

        # Scales up as nodes are reported as used:
        # 1.0 nodes used => target nodes = 2 => target workers = 1
        head_ip = services.get_node_ip_address()
        metrics.update(head_ip, {"CPU": 2}, {"CPU": 0}, {})  # head
        scaler.update()
        self.waitForNodes(1)

        # The new worker is idle and never used, but it must be kept
        # because the target is still 2 nodes.
        metrics.update("172.0.0.0", {"CPU": 0}, {"CPU": 0}, {})
        metrics.last_used_time_by_ip["172.0.0.0"] = 0
        scaler.update()
        assert live_count() == 1

        # Reduce load on head => target nodes = 1 => target workers = 0
        metrics.update(head_ip, {"CPU": 2}, {"CPU": 1}, {})
        scaler.update()
        assert live_count() == 0
示例#21
0
 def testConfiguresNewNodes(self):
     """Running nodes move from "uninitialized" to "up-to-date"."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     runner = MockProcessRunner()
     scaler = StandardAutoscaler(
         config_path,
         LoadMetrics(),
         node_updater_cls=NodeUpdaterThread,
         verbose_updates=True,
         process_runner=runner,
         max_failures=0,
         update_interval_s=0)
     for _ in range(2):
         scaler.update()
     self.waitForNodes(2)
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     not_ready = self.provider.nodes({TAG_RAY_NODE_STATUS: "uninitialized"})
     assert len(not_ready) == 2
     scaler.update()
     self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
示例#22
0
    def testTerminateOutdatedNodesGracefully(self):
        """Ten pre-existing workers are trimmed to the 5-worker target."""
        config = SMALL_CLUSTER.copy()
        config["min_workers"] = 5
        config["max_workers"] = 5
        config_path = self.write_config(config)
        self.provider = MockProvider()
        self.provider.create_node({}, {TAG_RAY_NODE_TYPE: "worker"}, 10)
        runner = MockProcessRunner()
        scaler = StandardAutoscaler(
            config_path,
            LoadMetrics(),
            update_interval_s=0,
            process_runner=runner,
            max_failures=0)
        self.waitForNodes(10)

        # Scale down toward the target (at most 5, at least 4 after every
        # update), never going too low.
        for _ in range(10):
            scaler.update()
            self.waitForNodes(5, comparison=self.assertLessEqual)
            self.waitForNodes(4, comparison=self.assertGreaterEqual)

        # Eventually reaches steady state
        self.waitForNodes(5)
 def testScaleUp(self):
     """Two nodes appear on the first update; a second update adds none."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(config_path, max_failures=0)
     self.assertEqual(len(self.provider.nodes({})), 0)
     for _ in range(2):
         scaler.update()
         self.assertEqual(len(self.provider.nodes({})), 2)
示例#24
0
    def __init__(self, redis_address, autoscaling_config, redis_password=None):
        """Connect to Redis and set up load metrics and the autoscaler.

        Args:
            redis_address: Address of the primary Redis server. It is split
                into IP and port with get_ip_address / get_port.
            autoscaling_config: Config passed to StandardAutoscaler. When it
                is falsy, no autoscaler is created.
            redis_password: Optional password for the Redis servers.
        """
        # Initialize the Redis clients. Parse the ``redis_address`` argument
        # itself, not a script-level ``args`` global, which does not exist
        # when this class is constructed from other code.
        self.state = ray.experimental.state.GlobalState()
        redis_ip_address = get_ip_address(redis_address)
        redis_port = get_port(redis_address)
        self.state._initialize_global_state(
            redis_ip_address, redis_port, redis_password=redis_password)
        self.redis = ray.services.create_redis_client(
            redis_address, password=redis_password)
        # Setup subscriptions to the primary Redis server and the Redis shards.
        self.primary_subscribe_client = self.redis.pubsub(
            ignore_subscribe_messages=True)
        # Keep a mapping from local scheduler client ID to IP address to use
        # for updating the load metrics.
        self.local_scheduler_id_to_ip_map = {}
        self.load_metrics = LoadMetrics()
        if autoscaling_config:
            self.autoscaler = StandardAutoscaler(autoscaling_config,
                                                 self.load_metrics)
        else:
            self.autoscaler = None

        # Experimental feature: GCS flushing, enabled via RAY_USE_NEW_GCS.
        self.issue_gcs_flushes = "RAY_USE_NEW_GCS" in os.environ
        self.gcs_flush_policy = None
        if self.issue_gcs_flushes:
            # Data is stored under the first data shard, so we issue flushes to
            # that redis server.
            shard_addresses = self.redis.lrange("RedisShards", 0, -1)
            if len(shard_addresses) > 1:
                logger.warning(
                    "Monitor: "
                    "TODO: if launching > 1 redis shard, flushing needs to "
                    "touch shards in parallel.")
                self.issue_gcs_flushes = False
            else:
                # The shard entry is bytes of the form b"host:port".
                addr_port = shard_addresses[0].split(b":")
                self.redis_shard = redis.StrictRedis(
                    host=addr_port[0],
                    port=addr_port[1],
                    password=redis_password)
                try:
                    # If the server rejects HEAD.FLUSH, fall back to not
                    # issuing flushes at all.
                    self.redis_shard.execute_command("HEAD.FLUSH 0")
                except redis.exceptions.ResponseError as e:
                    logger.info(
                        "Monitor: "
                        "Turning off flushing due to exception: {}".format(
                            str(e)))
                    self.issue_gcs_flushes = False
    def testDynamicScaling(self):
        """Edits to max_workers in the config are applied by update().

        Growing from 1 to 10 nodes takes two updates (6, then 10),
        presumably because max_concurrent_launches=5 caps each wave.
        """
        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(config_path,
                                        max_concurrent_launches=5,
                                        max_failures=0)
        self.assertEqual(len(self.provider.nodes({})), 0)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 2)

        # Update the config to reduce the cluster size
        new_config = SMALL_CLUSTER.copy()
        new_config["max_workers"] = 1
        self.write_config(new_config)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 1)

        # Update the config to increase the cluster size
        new_config["max_workers"] = 10
        self.write_config(new_config)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 6)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 10)
示例#26
0
 def __init__(self, redis_address, autoscaling_config, redis_password=None):
     """Set up Redis clients, load metrics and, optionally, the autoscaler.

     Args:
         redis_address: Address of the primary Redis server.
         autoscaling_config: Config for StandardAutoscaler. A falsy value
             disables autoscaling.
         redis_password: Optional Redis password.
     """
     # Initialize the Redis clients.
     ray.state.state._initialize_global_state(
         redis_address, redis_password=redis_password)
     self.redis = ray.services.create_redis_client(
         redis_address, password=redis_password)
     # Point the global worker at this Redis client and set its mode, so
     # that _internal_kv works for the autoscaler.
     global_worker = ray.worker.global_worker
     global_worker.redis_client = self.redis
     global_worker.mode = 0  # NOTE(review): meaning of mode 0 not shown here
     # Subscription client for the primary Redis server and the shards.
     self.primary_subscribe_client = self.redis.pubsub(
         ignore_subscribe_messages=True)
     # Raylet client ID -> IP address, used when updating the load metrics.
     self.raylet_id_to_ip_map = {}
     self.load_metrics = LoadMetrics()
     if not autoscaling_config:
         self.autoscaler = None
         self.autoscaling_config = None
     else:
         self.autoscaler = StandardAutoscaler(autoscaling_config,
                                              self.load_metrics)
         self.autoscaling_config = autoscaling_config
示例#27
0
 def __init__(self, redis_address, redis_port, autoscaling_config):
     """Connect to Redis and initialize the monitor's bookkeeping.

     Args:
         redis_address: Host of the primary Redis server.
         redis_port: Port of the primary Redis server.
         autoscaling_config: Config for StandardAutoscaler. A falsy value
             disables autoscaling.
     """
     # Initialize the Redis clients.
     self.state = ray.experimental.state.GlobalState()
     self.state._initialize_global_state(redis_address, redis_port)
     self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
     # TODO(swang): Update pubsub client to use ray.experimental.state once
     # subscriptions are implemented there.
     self.subscribe_client = self.redis.pubsub()
     self.subscribed = {}
     # Bookkeeping for the active database clients.
     self.dead_local_schedulers = set()
     self.live_plasma_managers = Counter()
     self.dead_plasma_managers = set()
     # Local scheduler client ID -> IP address, used when updating the load
     # metrics.
     self.local_scheduler_id_to_ip_map = {}
     self.load_metrics = LoadMetrics()
     self.autoscaler = (
         StandardAutoscaler(autoscaling_config, self.load_metrics)
         if autoscaling_config else None)
示例#28
0
 def testScaleUp(self):
     """Scale-up reaches two nodes, and a repeat update keeps it there."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(
         config_path, LoadMetrics(), update_interval_s=0, max_failures=0)
     assert len(self.provider.nodes({})) == 0
     for _ in range(2):
         scaler.update()
         self.waitForNodes(2)
示例#29
0
 def testMaxFailures(self):
     """With a throwing provider, two failed updates are tolerated and the
     third one raises."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     self.provider.throw = True
     scaler = StandardAutoscaler(
         config_path, LoadMetrics(), update_interval_s=0, max_failures=2)
     for _ in range(2):
         scaler.update()
     with pytest.raises(Exception):
         scaler.update()
 def testLaunchNewNodeOnOutOfBandTerminate(self):
     """Nodes terminated outside the autoscaler are relaunched."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(config_path, max_failures=0)

     def node_count():
         return len(self.provider.nodes({}))

     for _ in range(2):
         scaler.update()
     self.assertEqual(node_count(), 2)
     # Terminate every node behind the autoscaler's back.
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "terminated"
     self.assertEqual(node_count(), 0)
     scaler.update()
     self.assertEqual(node_count(), 2)
示例#31
0
 def testInitialWorkers(self):
     """initial_workers=10 is reached in two launch waves of 5 nodes."""
     config = SMALL_CLUSTER.copy()
     config["min_workers"] = 0
     config["max_workers"] = 20
     config["initial_workers"] = 10
     config_path = self.write_config(config)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(config_path,
                                 LoadMetrics(),
                                 update_interval_s=0,
                                 max_failures=0,
                                 max_concurrent_launches=5,
                                 max_launch_batch=5)
     self.waitForNodes(0)
     # Each update brings up at most 5 nodes (batch size and concurrency
     # are both 5).
     for expected in (5, 10):
         scaler.update()
         self.waitForNodes(expected)
     # NOTE(review): this last update has no follow-up check here; the
     # example may have been truncated.
     scaler.update()
    def testIgnoresCorruptedConfig(self):
        """An unparsable config is ignored; a later valid config is applied."""
        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(config_path,
                                        max_concurrent_launches=10,
                                        max_failures=0)
        autoscaler.update()

        # Write a corrupted config; repeated updates should leave the
        # cluster unchanged.
        self.write_config("asdf")
        for _ in range(10):
            autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 2)

        # Now write a good config again
        new_config = SMALL_CLUSTER.copy()
        new_config["max_workers"] = 10
        self.write_config(new_config)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 10)
示例#33
0
 def testLaunchNewNodeOnOutOfBandTerminate(self):
     """Nodes terminated outside the autoscaler get replaced."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(
         config_path, LoadMetrics(), update_interval_s=0, max_failures=0)
     for _ in range(2):
         scaler.update()
     self.waitForNodes(2)
     # Terminate every node behind the autoscaler's back.
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "terminated"
     assert len(self.provider.nodes({})) == 0
     scaler.update()
     self.waitForNodes(2)
示例#34
0
 def testScaleUpMinSanity(self):
     """MULTI_WORKER_CLUSTER scales up to two nodes and stays there."""
     config_path = self.write_config(MULTI_WORKER_CLUSTER)
     self.provider = MockProvider()
     runner = MockProcessRunner()
     scaler = StandardAutoscaler(config_path,
                                 LoadMetrics(),
                                 update_interval_s=0,
                                 process_runner=runner,
                                 max_failures=0)
     assert len(self.provider.non_terminated_nodes({})) == 0
     for _ in range(2):
         scaler.update()
         self.waitForNodes(2)
示例#35
0
    def testRecoverUnhealthyWorkers(self):
        """A worker with a stale heartbeat causes new runner commands."""
        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        runner = MockProcessRunner()
        metrics = LoadMetrics()
        scaler = StandardAutoscaler(config_path,
                                    metrics,
                                    update_interval_s=0,
                                    process_runner=runner,
                                    max_failures=0)
        scaler.update()
        self.waitForNodes(2)
        for mock_node in self.provider.mock_nodes.values():
            mock_node.state = "running"
        scaler.update()
        self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})

        # Zero one worker's heartbeat time so it looks unhealthy, then wait
        # (up to 150 retries) for the runner to receive further commands.
        metrics.last_heartbeat_time_by_ip["172.0.0.0"] = 0
        calls_before = len(runner.calls)
        scaler.update()
        self.waitFor(lambda: len(runner.calls) > calls_before, num_retries=150)
示例#36
0
 def testUpdateThrottling(self):
     """A second update() within update_interval_s does not apply changes."""
     config_path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     scaler = StandardAutoscaler(config_path,
                                 LoadMetrics(),
                                 update_interval_s=10,
                                 max_failures=0,
                                 max_concurrent_launches=5)

     def node_count():
         return len(self.provider.nodes({}))

     scaler.update()
     self.assertEqual(node_count(), 2)
     smaller = SMALL_CLUSTER.copy()
     smaller["max_workers"] = 1
     self.write_config(smaller)
     scaler.update()
     self.assertEqual(node_count(), 2)  # not updated yet
示例#37
0
文件: monitor.py 项目: adgirish/ray
 def __init__(self, redis_address, redis_port, autoscaling_config):
     # Global-state accessor plus a raw Redis connection to the same server.
     self.state = ray.experimental.state.GlobalState()
     self.state._initialize_global_state(redis_address, redis_port)
     self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
     # TODO(swang): Update pubsub client to use ray.experimental.state once
     # subscriptions are implemented there.
     self.subscribe_client = self.redis.pubsub()
     self.subscribed = {}
     # Bookkeeping for database clients that are alive or have died.
     self.dead_local_schedulers = set()
     self.live_plasma_managers = Counter()
     self.dead_plasma_managers = set()
     self.load_metrics = LoadMetrics()
     # The autoscaler is only created when an autoscaling config is given.
     self.autoscaler = None
     if autoscaling_config:
         self.autoscaler = StandardAutoscaler(autoscaling_config,
                                              self.load_metrics)
示例#38
0
 def testConfiguresNewNodes(self):
     """Freshly launched nodes end up tagged STATUS_UP_TO_DATE."""
     path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     process_runner = MockProcessRunner()
     autoscaler = StandardAutoscaler(
         path,
         LoadMetrics(),
         max_failures=0,
         process_runner=process_runner,
         update_interval_s=0)
     for _ in range(2):
         autoscaler.update()
     self.waitForNodes(2)
     self.provider.finish_starting_nodes()
     autoscaler.update()
     up_to_date = {TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE}
     self.waitForNodes(2, tag_filters=up_to_date)
示例#39
0
 def testConfiguresNewNodes(self):
     """Once the mock nodes are running they become "up-to-date"."""
     path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     process_runner = MockProcessRunner()
     autoscaler = StandardAutoscaler(
         path,
         LoadMetrics(),
         max_failures=0,
         process_runner=process_runner,
         update_interval_s=0)
     for _ in range(2):
         autoscaler.update()
     self.waitForNodes(2)
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     autoscaler.update()
     up_to_date = {TAG_RAY_NODE_STATUS: "up-to-date"}
     self.waitForNodes(2, tag_filters=up_to_date)
示例#40
0
 def testConfiguresOutdatedNodes(self):
     """New worker_setup_commands cause further runner calls."""
     path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     process_runner = MockProcessRunner()
     autoscaler = StandardAutoscaler(
         path,
         LoadMetrics(),
         max_failures=0,
         process_runner=process_runner,
         update_interval_s=0)
     for _ in range(2):
         autoscaler.update()
     self.waitForNodes(2)
     self.provider.finish_starting_nodes()
     autoscaler.update()
     self.waitForNodes(
         2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
     # Forget earlier calls so only post-change activity is observed.
     process_runner.calls = []
     updated = SMALL_CLUSTER.copy()
     updated["worker_setup_commands"] = ["cmdX", "cmdY"]
     self.write_config(updated)
     for _ in range(2):
         autoscaler.update()
     self.waitFor(lambda: len(process_runner.calls) > 0)
示例#41
0
 def testInitialWorkers(self):
     """initial_workers are launched across updates, limited by batch size."""
     config = SMALL_CLUSTER.copy()
     config.update(min_workers=0, max_workers=20, initial_workers=10)
     path = self.write_config(config)
     self.provider = MockProvider()
     autoscaler = StandardAutoscaler(
         path,
         LoadMetrics(),
         max_launch_batch=5,
         max_concurrent_launches=5,
         max_failures=0,
         update_interval_s=0)
     self.waitForNodes(0)
     autoscaler.update()
     # Only 5 so far: max_launch_batch / max_concurrent_launches cap it.
     self.waitForNodes(5)
     autoscaler.update()
     self.waitForNodes(10)
     autoscaler.update()
    def testLaunchConfigChange(self):
        """Changing the worker node config replaces every existing worker.

        The modified config is deep-copied from ``SMALL_CLUSTER``. A shallow
        ``.copy()`` would share the nested ``worker_nodes`` dict, so editing
        ``InstanceType`` would silently mutate the module-level constant for
        every test that runs afterwards.
        """
        import copy

        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(config_path, max_failures=0)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 2)

        # Update the config to change the node type (on a private deep copy).
        new_config = copy.deepcopy(SMALL_CLUSTER)
        new_config["worker_nodes"]["InstanceType"] = "updated"
        self.write_config(new_config)
        existing_nodes = set(self.provider.nodes({}))
        for _ in range(5):
            autoscaler.update()
        new_nodes = set(self.provider.nodes({}))
        self.assertEqual(len(new_nodes), 2)
        # None of the original nodes may survive the launch-config change.
        self.assertEqual(len(new_nodes.intersection(existing_nodes)), 0)
 def testConfiguresOutdatedNodes(self):
     """New worker_init_commands cause further runner calls."""
     path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     process_runner = MockProcessRunner()
     autoscaler = StandardAutoscaler(
         path,
         max_failures=0,
         process_runner=process_runner,
         verbose_updates=True,
         node_updater_cls=NodeUpdaterThread)
     for _ in range(2):
         autoscaler.update()
     self.assertEqual(len(self.provider.nodes({})), 2)
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     autoscaler.update()
     self.waitFor(
         lambda: len(self.provider.nodes(
             {TAG_RAY_NODE_STATUS: "Up-to-date"})) == 2)
     # Forget earlier calls so only post-change activity is observed.
     process_runner.calls = []
     updated = SMALL_CLUSTER.copy()
     updated["worker_init_commands"] = ["cmdX", "cmdY"]
     self.write_config(updated)
     for _ in range(2):
         autoscaler.update()
     self.waitFor(lambda: len(process_runner.calls) > 0)
示例#44
0
    def testIgnoresCorruptedConfig(self):
        """A corrupted config is ignored until a valid one is written."""
        path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(
            path,
            LoadMetrics(),
            max_concurrent_launches=10,
            max_failures=0,
            update_interval_s=0)
        autoscaler.update()

        # Garbage config: repeated updates must leave the cluster untouched.
        self.write_config("asdf")
        for _ in range(10):
            autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 2)

        # Restore a valid config that pins the cluster at 10 workers.
        fixed = SMALL_CLUSTER.copy()
        fixed["min_workers"] = 10
        fixed["max_workers"] = 10
        self.write_config(fixed)
        autoscaler.update()
        self.assertEqual(len(self.provider.nodes({})), 10)
示例#45
0
    def testLaunchConfigChange(self):
        """Changing the worker node config terminates and relaunches workers.

        The modified config is deep-copied from ``SMALL_CLUSTER``. A shallow
        ``.copy()`` would share the nested ``worker_nodes`` dict, so editing
        ``InstanceType`` would mutate the module-level constant for later
        tests.
        """
        import copy

        config_path = self.write_config(SMALL_CLUSTER)
        self.provider = MockProvider()
        autoscaler = StandardAutoscaler(
            config_path, LoadMetrics(), max_failures=0, update_interval_s=0)
        autoscaler.update()
        self.waitForNodes(2)

        # Update the config to change the node type (on a private deep copy).
        new_config = copy.deepcopy(SMALL_CLUSTER)
        new_config["worker_nodes"]["InstanceType"] = "updated"
        self.write_config(new_config)
        # Block node creation so the outdated nodes can be observed going
        # away before any replacements are created.
        self.provider.ready_to_create.clear()
        for _ in range(5):
            autoscaler.update()
        self.waitForNodes(0)
        self.provider.ready_to_create.set()
        self.waitForNodes(2)
 def testReportsConfigFailures(self):
     """When the runner fails "cmd1", both nodes become UpdateFailed."""
     path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     process_runner = MockProcessRunner(fail_cmds=["cmd1"])
     autoscaler = StandardAutoscaler(
         path,
         max_failures=0,
         process_runner=process_runner,
         verbose_updates=True,
         node_updater_cls=NodeUpdaterThread)
     for _ in range(2):
         autoscaler.update()
     self.assertEqual(len(self.provider.nodes({})), 2)
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     uninitialized = self.provider.nodes(
         {TAG_RAY_NODE_STATUS: "Uninitialized"})
     assert len(uninitialized) == 2
     autoscaler.update()
     self.waitFor(
         lambda: len(self.provider.nodes(
             {TAG_RAY_NODE_STATUS: "UpdateFailed"})) == 2)
 def testUpdateConfig(self):
     """Changing a node type's config with min_workers=0 drains workers.

     The config is deep-copied from ``MULTI_WORKER_CLUSTER``. A shallow
     ``.copy()`` would share the nested ``available_node_types`` dicts, so
     setting ``field_changed`` would mutate the module-level constant for
     every later test.
     """
     import copy

     config = copy.deepcopy(MULTI_WORKER_CLUSTER)
     config_path = self.write_config(config)
     self.provider = MockProvider()
     runner = MockProcessRunner()
     autoscaler = StandardAutoscaler(config_path,
                                     LoadMetrics(),
                                     max_failures=0,
                                     process_runner=runner,
                                     update_interval_s=0)
     assert len(self.provider.non_terminated_nodes({})) == 0
     autoscaler.update()
     self.waitForNodes(2)
     config["min_workers"] = 0
     config["available_node_types"]["m4.large"]["node_config"][
         "field_changed"] = 1
     # Same config file path is rewritten; the returned path is not needed.
     self.write_config(config)
     autoscaler.update()
     self.waitForNodes(0)
示例#48
0
 def testReportsConfigFailures(self):
     """A failing "setup_cmd" leaves both nodes in STATUS_UPDATE_FAILED."""
     config = copy.deepcopy(SMALL_CLUSTER)
     # Run prepare_config as an "external" provider, then switch to the mock.
     config["provider"]["type"] = "external"
     config = prepare_config(config)
     config["provider"]["type"] = "mock"
     path = self.write_config(config)
     self.provider = MockProvider()
     process_runner = MockProcessRunner(fail_cmds=["setup_cmd"])
     autoscaler = StandardAutoscaler(
         path,
         LoadMetrics(),
         max_failures=0,
         process_runner=process_runner,
         update_interval_s=0)
     for _ in range(2):
         autoscaler.update()
     self.waitForNodes(2)
     self.provider.finish_starting_nodes()
     autoscaler.update()
     failed = {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED}
     self.waitForNodes(2, tag_filters=failed)
示例#49
0
 def testConfiguresNewNodes(self):
     """Running nodes move from "uninitialized" to "up-to-date"."""
     path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     process_runner = MockProcessRunner()
     autoscaler = StandardAutoscaler(
         path,
         LoadMetrics(),
         max_failures=0,
         process_runner=process_runner,
         verbose_updates=True,
         node_updater_cls=NodeUpdaterThread,
         update_interval_s=0)
     for _ in range(2):
         autoscaler.update()
     self.waitForNodes(2)
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     pending = self.provider.nodes({TAG_RAY_NODE_STATUS: "uninitialized"})
     assert len(pending) == 2
     autoscaler.update()
     self.waitForNodes(2, tag_filters={TAG_RAY_NODE_STATUS: "up-to-date"})
示例#50
0
 def testReportsConfigFailures(self):
     """When the runner fails "cmd1", nodes end up "update-failed"."""
     config = copy.deepcopy(SMALL_CLUSTER)
     # Run fillout_defaults as an "external" provider, then switch to mock.
     config["provider"]["type"] = "external"
     config = fillout_defaults(config)
     config["provider"]["type"] = "mock"
     path = self.write_config(config)
     self.provider = MockProvider()
     process_runner = MockProcessRunner(fail_cmds=["cmd1"])
     autoscaler = StandardAutoscaler(
         path,
         LoadMetrics(),
         max_failures=0,
         process_runner=process_runner,
         update_interval_s=0)
     for _ in range(2):
         autoscaler.update()
     self.waitForNodes(2)
     for mock_node in self.provider.mock_nodes.values():
         mock_node.state = "running"
     autoscaler.update()
     failed = {TAG_RAY_NODE_STATUS: "update-failed"}
     self.waitForNodes(2, tag_filters=failed)
示例#51
0
 def testUpdateThrottling(self):
     """With update_interval_s=10, a shrink is not applied immediately."""
     path = self.write_config(SMALL_CLUSTER)
     self.provider = MockProvider()
     autoscaler = StandardAutoscaler(
         path,
         LoadMetrics(),
         max_launch_batch=5,
         max_concurrent_launches=5,
         max_failures=0,
         update_interval_s=10)
     autoscaler.update()
     self.waitForNodes(2)
     assert autoscaler.num_launches_pending.value == 0
     shrunk = SMALL_CLUSTER.copy()
     shrunk["max_workers"] = 1
     self.write_config(shrunk)
     autoscaler.update()
     # Not updated yet. Node termination happens in the main thread, so no
     # delay is needed before checking.
     assert len(self.provider.non_terminated_nodes({})) == 2
     assert autoscaler.num_launches_pending.value == 0
示例#52
0
    def testScaleUpIgnoreUsed(self):
        """A waiting GPU bundle leads to a second p2.xlarge node."""
        config = MULTI_WORKER_CLUSTER.copy()
        # Commenting out this line causes the test case to fail?!?!
        config["min_workers"] = 0
        config["target_utilization_fraction"] = 1.0
        path = self.write_config(config)
        self.provider = MockProvider()
        head_tags = {
            TAG_RAY_NODE_KIND: "head",
            TAG_RAY_USER_NODE_TYPE: "p2.xlarge",
        }
        self.provider.create_node({}, head_tags, 1)
        head_ip = self.provider.non_terminated_node_ips({})[0]
        self.provider.finish_starting_nodes()
        process_runner = MockProcessRunner()
        metrics = LoadMetrics(local_ip=head_ip)
        autoscaler = StandardAutoscaler(
            path,
            metrics,
            max_failures=0,
            process_runner=process_runner,
            update_interval_s=0)
        autoscaler.update()
        self.waitForNodes(1)
        metrics.update(head_ip, {"CPU": 4, "GPU": 1}, {}, {})
        self.waitForNodes(1)

        # Report one pending GPU bundle for the head node.
        metrics.update(
            head_ip, {"CPU": 4, "GPU": 1}, {"GPU": 1}, {},
            waiting_bundles=[{"GPU": 1}])
        autoscaler.update()
        self.waitForNodes(2)
        assert self.provider.mock_nodes[1].node_type == "p2.xlarge"
示例#53
0
文件: monitor.py 项目: zerocurve/ray
class Monitor(object):
    """A monitor for Ray processes.

    The monitor is in charge of cleaning up the tables in the global state
    after processes have died. The monitor is currently not responsible for
    detecting component failures.

    Attributes:
        redis: A connection to the Redis server.
        use_raylet: A bool indicating whether to use the raylet code path or
            not.
        subscribe_client: A pubsub client for the Redis server. This is used to
            receive notifications about failed components.
        subscribed: A dictionary mapping channel names (str) to whether or not
            the subscription to that channel has succeeded yet (bool).
        dead_local_schedulers: A set of the local scheduler IDs of all of the
            local schedulers that were up at one point and have died since
            then.
        live_plasma_managers: A counter mapping live plasma manager IDs to the
            number of heartbeats that have passed since we last heard from that
            plasma manager. A plasma manager is live if we received a heartbeat
            from it at any point, and if it has not timed out.
        dead_plasma_managers: A set of the plasma manager IDs of all the plasma
            managers that were up at one point and have died since then.
    """
    def __init__(self, redis_address, redis_port, autoscaling_config):
        # Global-state accessor plus a raw Redis connection to the same server.
        self.state = ray.experimental.state.GlobalState()
        self.state._initialize_global_state(redis_address, redis_port)
        self.use_raylet = self.state.use_raylet
        self.redis = redis.StrictRedis(
            host=redis_address, port=redis_port, db=0)
        # TODO(swang): Update pubsub client to use ray.experimental.state once
        # subscriptions are implemented there.
        self.subscribe_client = self.redis.pubsub()
        self.subscribed = {}
        # Bookkeeping for database clients that are alive or have died.
        self.dead_local_schedulers = set()
        self.live_plasma_managers = Counter()
        self.dead_plasma_managers = set()
        # Local scheduler client ID -> IP address; heartbeat handlers use it
        # to attribute resource reports to a node in the load metrics.
        self.local_scheduler_id_to_ip_map = {}
        self.load_metrics = LoadMetrics()
        # The autoscaler is only created when an autoscaling config is given.
        self.autoscaler = None
        if autoscaling_config:
            self.autoscaler = StandardAutoscaler(autoscaling_config,
                                                 self.load_metrics)

    def subscribe(self, channel):
        """Subscribe to the given channel.

        Args:
            channel (str): The channel to subscribe to.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.subscribe_client.subscribe(channel)
        self.subscribed[channel] = False

    def cleanup_task_table(self):
        """Mark every task scheduled on a dead local scheduler as lost.

        A task is affected when its ``LocalSchedulerID`` is in
        ``self.dead_local_schedulers``. For actor tasks, the dummy object
        (the task's last return object ID) is first removed from the object
        table entries of its plasma manager. Each affected task is then
        updated to TASK_STATUS_LOST, and the total count is logged.
        """
        lost_count = 0
        for task_id, task in self.state.task_table().items():
            # Tasks on live local schedulers are left alone.
            if task["LocalSchedulerID"] not in self.dead_local_schedulers:
                continue

            spec = task["TaskSpec"]
            # Remove dummy objects returned by actor tasks from any plasma
            # manager. Although the objects may still exist in that object
            # store, this deletion makes them effectively unreachable by any
            # local scheduler connected to a different store.
            # TODO(swang): Actually remove the objects from the object store,
            # so that the reconstructed actor can reuse the same object store.
            if hex_to_binary(spec["ActorID"]) != NIL_ACTOR_ID:
                dummy_id = spec["ReturnObjectIDs"][-1]
                managers = self.state.object_table(dummy_id)["ManagerIDs"]
                if managers is not None:
                    # Only the manager paired with the dead local scheduler
                    # should hold the dummy object, so at most one entry.
                    assert len(managers) <= 1
                    for manager in managers:
                        reply = self.state._execute_command(
                            dummy_id, "RAY.OBJECT_TABLE_REMOVE",
                            dummy_id.id(), hex_to_binary(manager))
                        if reply != b"OK":
                            log.warn("Failed to remove object location for "
                                     "dead plasma manager.")

            # The task was scheduled on a dead local scheduler: mark it lost.
            task_key = binary_to_object_id(hex_to_binary(task_id))
            reply = self.state._execute_command(
                task_key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
                ray.experimental.state.TASK_STATUS_LOST, NIL_ID,
                task["ExecutionDependenciesString"], task["SpillbackCount"])
            if reply != b"OK":
                log.warn("Failed to update lost task for dead scheduler.")
            lost_count += 1

        if lost_count > 0:
            log.warn("Marked {} tasks as lost.".format(lost_count))

    def cleanup_object_table(self):
        """Drop dead plasma managers from object table location entries.

        A plasma manager counts as dead when it is in
        ``self.dead_plasma_managers``. Every location removed (successfully
        or not) is counted, and the total is logged when non-zero.
        """
        # TODO(swang): Also kill the associated plasma store, since it's no
        # longer reachable without a plasma manager.
        removed = 0
        for object_id, obj in self.state.object_table().items():
            managers = obj["ManagerIDs"]
            if managers is None:
                continue
            for manager in managers:
                if manager not in self.dead_plasma_managers:
                    continue
                # This location lives on a dead plasma manager; remove it.
                reply = self.state._execute_command(
                    object_id, "RAY.OBJECT_TABLE_REMOVE", object_id.id(),
                    hex_to_binary(manager))
                if reply != b"OK":
                    log.warn("Failed to remove object location for dead "
                             "plasma manager.")
                removed += 1
        if removed > 0:
            log.warn("Marked {} objects as lost.".format(removed))

    def scan_db_client_table(self):
        """Record already-deleted clients found in the database client table.

        Call this after subscribing to the client table and before reading
        any messages from the subscription channel, so that deletions that
        happened before the subscription are not missed. Deleted local
        schedulers and plasma managers are added to the matching dead set.
        Does nothing on the raylet code path.
        """
        # Exit if we are using the raylet code path because client_table is
        # implemented differently. TODO(rkn): Fix this.
        if self.use_raylet:
            return

        for node_clients in self.state.client_table().values():
            for client in node_clients:
                client_id = client["DBClientID"]
                kind = client["ClientType"]
                if not client["Deleted"]:
                    continue
                if kind == LOCAL_SCHEDULER_CLIENT_TYPE:
                    self.dead_local_schedulers.add(client_id)
                elif kind == PLASMA_MANAGER_CLIENT_TYPE:
                    self.dead_plasma_managers.add(client_id)

    def subscribe_handler(self, channel, data):
        """Mark ``channel`` as confirmed after Redis acknowledges it."""
        message = "Subscribed to {}, data was {}".format(channel, data)
        log.debug(message)
        self.subscribed[channel] = True

    def db_client_notification_handler(self, unused_channel, data):
        """Record a database client deletion published on the db_client table.

        ``data`` is parsed as a SubscribeToDBClientTableReply flatbuffer.
        Insertions are ignored. A deletion adds the client to the matching
        dead-client set; a deleted plasma manager is also dropped from
        heartbeat tracking. Cleaning up the associated state in the state
        tables is left to the caller.
        """
        reply = (SubscribeToDBClientTableReply.
                 GetRootAsSubscribeToDBClientTableReply(data, 0))
        client_id = binary_to_hex(reply.DbClientId())
        client_type = reply.ClientType()
        if reply.IsInsertion():
            # Only deletions matter here.
            return

        log.warn("Removed {}, client ID {}".format(client_type, client_id))
        if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
            # set.add is already a no-op for IDs that are present.
            self.dead_local_schedulers.add(client_id)
        elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
            self.dead_plasma_managers.add(client_id)
            # The manager is dead, so stop tracking its heartbeats. Counter's
            # __delitem__ does not raise for missing keys.
            del self.live_plasma_managers[client_id]

    def local_scheduler_info_handler(self, unused_channel, data):
        """Update load metrics from a local scheduler heartbeat.

        ``data`` is parsed as a LocalSchedulerInfoMessage flatbuffer. Its
        static and dynamic resource vectors are read into name -> value
        dicts and passed to ``self.load_metrics.update`` for the IP mapped
        to the scheduler's client ID. If no IP is known for that client, a
        warning is logged and the heartbeat is dropped.
        """
        message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
            data, 0)
        # Read each vector using its own length. Indexing the static vector
        # with the dynamic vector's length silently drops (or over-reads)
        # entries whenever the two vectors differ in size.
        static_resources = {}
        for i in range(message.StaticResourcesLength()):
            entry = message.StaticResources(i)
            static_resources[entry.Key().decode("utf-8")] = entry.Value()
        dynamic_resources = {}
        for i in range(message.DynamicResourcesLength()):
            entry = message.DynamicResources(i)
            dynamic_resources[entry.Key().decode("utf-8")] = entry.Value()

        # Update the load metrics for this local scheduler.
        client_id = binascii.hexlify(message.DbClientId()).decode("utf-8")
        ip = self.local_scheduler_id_to_ip_map.get(client_id)
        if ip:
            self.load_metrics.update(ip, static_resources, dynamic_resources)
        else:
            # Report through the module logger like the rest of this class.
            log.warning("Could not find ip for client %s.", client_id)

    def xray_heartbeat_handler(self, unused_channel, data):
        """Update load metrics from an xray (raylet) heartbeat.

        ``data`` is parsed as a GcsTableEntry flatbuffer whose first entry is
        a HeartbeatTableData message. Total and available resources are read
        into label -> capacity dicts and passed to
        ``self.load_metrics.update`` for the IP mapped to the message's
        client ID. If no IP is known for that client, a warning is logged
        and the heartbeat is dropped.
        """
        gcs_entries = GcsTableEntry.GetRootAsGcsTableEntry(data, 0)
        heartbeat_data = gcs_entries.Entries(0)
        message = HeartbeatTableData.GetRootAsHeartbeatTableData(
            heartbeat_data, 0)
        # Read the total and available vectors with their own lengths;
        # nothing guarantees they are the same size, and indexing the total
        # vector with the available vector's length can drop or over-read
        # entries.
        static_resources = {}
        for i in range(message.ResourcesTotalLabelLength()):
            static_resources[message.ResourcesTotalLabel(i)] = (
                message.ResourcesTotalCapacity(i))
        dynamic_resources = {}
        for i in range(message.ResourcesAvailableLabelLength()):
            dynamic_resources[message.ResourcesAvailableLabel(i)] = (
                message.ResourcesAvailableCapacity(i))

        # Update the load metrics for this local scheduler.
        client_id = message.ClientId().decode("utf-8")
        ip = self.local_scheduler_id_to_ip_map.get(client_id)
        if ip:
            self.load_metrics.update(ip, static_resources, dynamic_resources)
        else:
            # Report through the module logger like the rest of this class.
            log.warning("Could not find ip for client %s.", client_id)

    def plasma_manager_heartbeat_handler(self, unused_channel, data):
        """Reset the missed-heartbeat count of the manager that sent ``data``.

        The sender's client ID is taken to be the first DB_CLIENT_ID_SIZE
        characters of the message payload.
        """
        manager_id = data[:DB_CLIENT_ID_SIZE]
        self.live_plasma_managers[manager_id] = 0

    def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
        """Collect IDs of control-state entries for a driver from a shard.

        Args:
            driver_id: The ID of the driver.
            redis_shard_index: The index of the Redis shard to query.

        Returns:
            Lists of IDs: (returned_object_ids, task_ids, put_objects). The
                first two are relevant to the driver and are safe to delete.
                The last contains all "put" objects in this redis shard; each
                element is an (object_id, corresponding task_id) pair.
        """
        # TODO(zongheng): consider adding save & restore functionalities.
        redis = self.state.redis_clients[redis_shard_index]
        task_table_infos = {}  # task id -> TaskInfo messages

        # Scan the task table & filter to get the list of tasks belong to this
        # driver.  Use a cursor in order not to block the redis shards.
        for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
            entry = redis.hgetall(key)
            task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
            if driver_id != task_info.DriverId():
                # Ignore tasks that aren't from this driver.
                continue
            task_table_infos[task_info.TaskId()] = task_info

        # Get the list of objects returned by these tasks.  Note these might
        # not belong to this redis shard.
        returned_object_ids = []
        for task_info in task_table_infos.values():
            returned_object_ids.extend(
                task_info.Returns(i) for i in range(task_info.ReturnsLength()))

        # Also record all the ray.put()'d objects.
        put_objects = []
        for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
            entry = redis.hgetall(key)
            # The hash is read with bytes field names, so its values are
            # bytes too; a plain str "0" never equals b"0" in Python 3 and
            # would let non-put objects through. Accept both forms.
            if entry[b"is_put"] in (b"0", "0"):
                continue
            # scan_iter matched OBJECT_INFO_PREFIX + "*", so strip exactly
            # the prefix (split() would break if the prefix recurs in the ID).
            object_id = key[len(OBJECT_INFO_PREFIX):]
            task_id = entry[b"task"]
            put_objects.append((object_id, task_id))

        return returned_object_ids, list(task_table_infos), put_objects

    def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index):
        """Delete task and object control-state entries from one Redis shard.

        Task-table keys are deleted for every ID in ``task_ids``.
        Object-location (OL) and object-info (OI) keys are deleted only for
        objects that currently have a non-empty entry in this shard.
        Deletion is best effort: a shortfall is logged, not raised.

        Args:
            object_ids: Object IDs whose OL/OI entries should be removed.
            task_ids: Task IDs whose task-table entries should be removed.
            shard_index: Index into ``self.state.redis_clients``.
        """
        redis = self.state.redis_clients[shard_index]
        # Clean up (in the future, save) entries for non-empty objects.
        object_ids_locs = set()
        object_ids_infos = set()
        for object_id in object_ids:
            # OL.
            if redis.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1):
                object_ids_locs.add(object_id)
            # OI.
            if redis.hgetall(OBJECT_INFO_PREFIX + object_id):
                object_ids_infos.add(object_id)

        # Form the redis keys to delete.
        keys = [TASK_TABLE_PREFIX + k for k in task_ids]
        keys.extend(OBJECT_LOCATION_PREFIX + k for k in object_ids_locs)
        keys.extend(OBJECT_INFO_PREFIX + k for k in object_ids_infos)

        if not keys:
            return
        # Remove with best effort.
        num_deleted = redis.delete(*keys)
        log.info(
            "Removed {} dead redis entries of the driver from redis shard {}.".
            format(num_deleted, shard_index))
        if num_deleted != len(keys):
            # Both placeholders need an argument; with only one, format()
            # raised IndexError exactly when a partial failure occurred.
            log.warning(
                "Failed to remove {} relevant redis entries"
                " from redis shard {}.".format(len(keys) - num_deleted,
                                               shard_index))

    def _clean_up_entries_for_driver(self, driver_id):
        """Remove this driver's object/task entries from all redis shards.

        Specifically, removes control-state entries of:
            * all objects (OI and OL entries) returned by the driver's tasks
              or created by `ray.put()` within one of those tasks
            * all tasks belonging to the driver.
        """
        # TODO(zongheng): handle function_table, client_table, log_files --
        # these are in the metadata redis server, not in the shards.
        num_shards = len(self.state.redis_clients)
        object_ids = []
        task_ids = []
        put_objects = []

        # Gather the relevant IDs from every shard.
        # TODO(zongheng): consider parallelizing this loop.
        for shard in range(num_shards):
            shard_objects, shard_tasks, shard_puts = (
                self._entries_for_driver_in_shard(driver_id, shard))
            object_ids.extend(shard_objects)
            task_ids.extend(shard_tasks)
            put_objects.extend(shard_puts)

        # Of the put objects, keep only those created by this driver's tasks.
        task_id_set = set(task_ids)
        object_ids.extend(
            obj_id for obj_id, creator in put_objects
            if creator in task_id_set)

        def shard_of(binary_id):
            # Redis shard that owns the entry for this binary ID.
            shard_hash = binary_to_object_id(binary_id).redis_shard_hash()
            return shard_hash % num_shards

        objects_by_shard = defaultdict(list)
        for obj_id in object_ids:
            objects_by_shard[shard_of(obj_id)].append(obj_id)
        tasks_by_shard = defaultdict(list)
        for t_id in task_ids:
            tasks_by_shard[shard_of(t_id)].append(t_id)

        # TODO(zongheng): consider parallelizing this loop.
        for shard in range(num_shards):
            self._clean_up_entries_from_shard(
                objects_by_shard[shard], tasks_by_shard[shard], shard)

    def driver_removed_handler(self, unused_channel, data):
        """Handle a notification that a driver has exited.

        Decodes the DriverTableMessage in `data`, logs the removal and
        deletes the driver's task/object entries from the redis shards.

        Args:
            unused_channel: The pubsub channel the message arrived on
                (ignored).
            data: The serialized DriverTableMessage.
        """
        driver_id = DriverTableMessage.GetRootAsDriverTableMessage(
            data, 0).DriverId()
        log.info("Driver {} has been removed.".format(
            binary_to_hex(driver_id)))
        self._clean_up_entries_for_driver(driver_id)

    def process_messages(self, max_messages=10000):
        """Drain ready messages from the subscription channels.

        Each message is dispatched to the handler for its channel. A message
        on a channel whose subscription has not been acknowledged yet is
        handled as the subscription confirmation. Returns as soon as no
        message is ready.

        Args:
            max_messages: The maximum number of messages to process before
                returning.

        Raises:
            Exception: If a message arrives on an acknowledged channel that
                has no known handler.
        """

        def handler_for(channel):
            # The first message on a fresh channel is Redis's reply to the
            # subscription request.
            if not self.subscribed[channel]:
                return self.subscribe_handler
            if channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
                # A heartbeat from a plasma manager.
                return self.plasma_manager_heartbeat_handler
            if channel == LOCAL_SCHEDULER_INFO_CHANNEL:
                # A heartbeat from a local scheduler.
                return self.local_scheduler_info_handler
            if channel == DB_CLIENT_TABLE_NAME:
                # A notification from the db_client table.
                return self.db_client_notification_handler
            if channel == DRIVER_DEATH_CHANNEL:
                # A notification that a driver was removed.
                log.info("message-handler: driver_removed_handler")
                return self.driver_removed_handler
            if channel == XRAY_HEARTBEAT_CHANNEL:
                # Similar functionality as local scheduler info channel.
                return self.xray_heartbeat_handler
            raise Exception("This code should be unreachable.")

        for _ in range(max_messages):
            message = self.subscribe_client.get_message()
            if message is None:
                return
            channel, data = message["channel"], message["data"]
            handler = handler_for(channel)
            assert (handler is not None)
            handler(channel, data)

    def update_local_scheduler_map(self):
        """Rebuild `self.local_scheduler_id_to_ip_map` from global state.

        Entries come from `client_table()` when running under raylet and
        from `local_schedulers()` otherwise. For each entry, the client ID is
        the "DBClientID" field when present and truthy, else "ClientID". The
        address is "AuxAddress" when present and truthy, else
        "NodeManagerAddress". Only the part before the first ":" is kept as
        the IP.
        """
        schedulers = (self.state.client_table() if self.use_raylet else
                      self.state.local_schedulers())
        id_to_ip = {}
        self.local_scheduler_id_to_ip_map = id_to_ip
        for info in schedulers:
            client_id = info.get("DBClientID") or info["ClientID"]
            address = info.get("AuxAddress") or info["NodeManagerAddress"]
            id_to_ip[client_id] = address.split(":")[0]

    def run(self):
        """Run the monitor.

        This function loops forever, checking for messages about dead database
        clients and cleaning up state accordingly.

        Startup: subscribe to the monitored pubsub channels, scan the
        db_client table for clients that were already dead, and clean up the
        task/object tables for them.

        Each round of the loop then:
            * refreshes the local scheduler ID -> IP map,
            * runs one autoscaler update, if an autoscaler is configured,
            * processes the pending subscription messages,
            * cleans up the task/object tables if new local schedulers or
              plasma managers were marked dead during this round,
            * disconnects plasma managers that have missed
              `ray._config.num_heartbeats_timeout()` or more heartbeats,
            * sleeps for one heartbeat interval.
        """
        # Initialize the subscription channels.
        self.subscribe(DB_CLIENT_TABLE_NAME)
        self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL)
        self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
        self.subscribe(DRIVER_DEATH_CHANNEL)
        self.subscribe(XRAY_HEARTBEAT_CHANNEL)

        # Scan the database table for dead database clients. NOTE: This must be
        # called before reading any messages from the subscription channel.
        # This ensures that we start in a consistent state, since we may have
        # missed notifications that were sent before we connected to the
        # subscription channel.
        self.scan_db_client_table()
        # If there were any dead clients at startup, clean up the associated
        # state in the state tables.
        if len(self.dead_local_schedulers) > 0:
            self.cleanup_task_table()
        if len(self.dead_plasma_managers) > 0:
            self.cleanup_object_table()

        # Only used for the debug log line below.
        num_plasma_managers = len(self.live_plasma_managers) + len(
            self.dead_plasma_managers)

        log.debug("{} dead local schedulers, {} plasma managers total, {} "
                  "dead plasma managers".format(
                      len(self.dead_local_schedulers), num_plasma_managers,
                      len(self.dead_plasma_managers)))

        # Handle messages from the subscription channels.
        while True:
            # Update the mapping from local scheduler client ID to IP address.
            # This is only used to update the load metrics for the autoscaler.
            self.update_local_scheduler_map()

            # Process autoscaling actions
            if self.autoscaler:
                self.autoscaler.update()

            # Record how many dead local schedulers and plasma managers we had
            # at the beginning of this round.
            num_dead_local_schedulers = len(self.dead_local_schedulers)
            num_dead_plasma_managers = len(self.dead_plasma_managers)
            # Process a round of messages.
            self.process_messages()
            # If any new local schedulers or plasma managers were marked as
            # dead in this round, clean up the associated state.
            if len(self.dead_local_schedulers) > num_dead_local_schedulers:
                self.cleanup_task_table()
            if len(self.dead_plasma_managers) > num_dead_plasma_managers:
                self.cleanup_object_table()

            # Handle plasma managers that timed out during this round.
            # Iterate over a snapshot of the keys because entries are deleted
            # from the counter inside the loop.
            plasma_manager_ids = list(self.live_plasma_managers.keys())
            for plasma_manager_id in plasma_manager_ids:
                if ((self.live_plasma_managers[plasma_manager_id]) >=
                        ray._config.num_heartbeats_timeout()):
                    # NOTE(review): this message does not include
                    # plasma_manager_id, so it can't identify which manager
                    # timed out.
                    log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
                    # Remove the plasma manager from the managers whose
                    # heartbeats we're tracking.
                    del self.live_plasma_managers[plasma_manager_id]
                    # Remove the plasma manager from the db_client table. The
                    # corresponding state in the object table will be cleaned
                    # up once we receive the notification for this db_client
                    # deletion.
                    self.redis.execute_command("RAY.DISCONNECT",
                                               plasma_manager_id)

            # Increment the number of heartbeats that we've missed from each
            # plasma manager. A received heartbeat resets the count to 0.
            for plasma_manager_id in self.live_plasma_managers:
                self.live_plasma_managers[plasma_manager_id] += 1

            # Wait for a heartbeat interval before processing the next round of
            # messages (the config value is in milliseconds).
            time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3)
Example #54
0
    def testScaleUpBasedOnLoad(self):
        # A cluster of 1..10 workers that targets 50% utilization.
        config = SMALL_CLUSTER.copy()
        config.update({
            "min_workers": 1,
            "max_workers": 10,
            "target_utilization_fraction": 0.5,
        })
        config_path = self.write_config(config)
        self.provider = MockProvider()
        load_metrics = LoadMetrics()
        autoscaler = StandardAutoscaler(
            config_path, load_metrics, max_failures=0, update_interval_s=0)

        def assert_settled(expected):
            # No launches in flight and exactly `expected` live nodes.
            assert autoscaler.num_launches_pending.value == 0
            assert len(self.provider.non_terminated_nodes({})) == expected

        assert len(self.provider.non_terminated_nodes({})) == 0
        autoscaler.update()
        self.waitForNodes(1)
        autoscaler.update()
        assert_settled(1)

        # Reporting busy nodes grows the cluster.
        head_ip = services.get_node_ip_address()
        load_metrics.update(head_ip, {"CPU": 2}, {"CPU": 0})  # head
        load_metrics.update("172.0.0.0", {"CPU": 2}, {"CPU": 0})  # worker 1
        autoscaler.update()
        self.waitForNodes(3)
        load_metrics.update("172.0.0.1", {"CPU": 2}, {"CPU": 0})
        autoscaler.update()
        self.waitForNodes(5)

        # Freeing the load alone does not shrink the cluster.
        for worker_ip in ("172.0.0.0", "172.0.0.1"):
            load_metrics.update(worker_ip, {"CPU": 2}, {"CPU": 2})
        autoscaler.update()
        assert_settled(5)

        # Nodes that have been idle long enough are removed.
        for worker_ip in ("172.0.0.0", "172.0.0.1"):
            load_metrics.last_used_time_by_ip[worker_ip] = 0
        autoscaler.update()
        assert_settled(3)
        for worker_ip in ("172.0.0.2", "172.0.0.3"):
            load_metrics.last_used_time_by_ip[worker_ip] = 0
        autoscaler.update()
        assert_settled(1)
Example #55
0
File: monitor.py  Project: adgirish/ray
class Monitor(object):
    """A monitor for Ray processes.

    The monitor is in charge of cleaning up the tables in the global state
    after processes have died. The monitor is currently not responsible for
    detecting component failures.

    Attributes:
        redis: A connection to the Redis server.
        subscribe_client: A pubsub client for the Redis server. This is used to
            receive notifications about failed components.
        subscribed: A dictionary mapping channel names (str) to whether or not
            the subscription to that channel has succeeded yet (bool).
        dead_local_schedulers: A set of the local scheduler IDs of all of the
            local schedulers that were up at one point and have died since
            then.
        live_plasma_managers: A counter mapping live plasma manager IDs to the
            number of heartbeats that have passed since we last heard from that
            plasma manager. A plasma manager is live if we received a heartbeat
            from it at any point, and if it has not timed out.
        dead_plasma_managers: A set of the plasma manager IDs of all the plasma
            managers that were up at one point and have died since then.
    """

    def __init__(self, redis_address, redis_port, autoscaling_config):
        # Global-state accessor for the primary Redis server and its shards.
        self.state = ray.experimental.state.GlobalState()
        self.state._initialize_global_state(redis_address, redis_port)
        # Direct client for the primary Redis server (database 0).
        self.redis = redis.StrictRedis(
            host=redis_address, port=redis_port, db=0)
        # TODO(swang): Update pubsub client to use ray.experimental.state once
        # subscriptions are implemented there.
        self.subscribe_client = self.redis.pubsub()
        self.subscribed = {}
        # Book-keeping for database clients that have died or are still
        # sending heartbeats.
        self.dead_local_schedulers = set()
        self.live_plasma_managers = Counter()
        self.dead_plasma_managers = set()
        # Resource load reported by local schedulers; shared with the
        # autoscaler when one is configured.
        self.load_metrics = LoadMetrics()
        self.autoscaler = (StandardAutoscaler(autoscaling_config,
                                              self.load_metrics)
                           if autoscaling_config else None)

    def subscribe(self, channel):
        """Subscribe to the given channel.

        Args:
            channel (str): The channel to subscribe to.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.subscribe_client.subscribe(channel)
        self.subscribed[channel] = False

    def cleanup_actors(self):
        """Recreate any live actors whose corresponding local scheduler died.

        For every actor that has not been removed and whose recorded local
        scheduler is in `self.dead_local_schedulers`, this:
            * chooses a new local scheduler for it,
            * publishes an actor-creation notification for that scheduler,
            * updates the actor's `local_scheduler_id` field in Redis.
        """
        actor_info = self.state.actors()
        for actor_id, info in actor_info.items():
            if info["removed"]:
                continue
            if info["local_scheduler_id"] not in self.dead_local_schedulers:
                continue

            # Choose a new local scheduler to run the actor.
            local_scheduler_id = ray.utils.select_local_scheduler(
                info["driver_id"],
                self.state.local_schedulers(), info["num_gpus"],
                self.redis)
            # The new local scheduler should not be the same as the old
            # local scheduler. TODO(rkn): This should not be an assert, it
            # should be something more benign.
            assert (binary_to_hex(local_scheduler_id) !=
                    info["local_scheduler_id"])
            actor_id_binary = hex_to_binary(actor_id)
            # Announce to all of the local schedulers that the actor should
            # be recreated on this new local scheduler.
            ray.utils.publish_actor_creation(
                actor_id_binary,
                hex_to_binary(info["driver_id"]), local_scheduler_id, True,
                self.redis)
            log.info("Actor {} for driver {} was on dead local scheduler "
                     "{}. It is being recreated on local scheduler {}"
                     .format(actor_id, info["driver_id"],
                             info["local_scheduler_id"],
                             binary_to_hex(local_scheduler_id)))
            # Update the actor info in Redis.
            self.redis.hset(b"Actor:" + actor_id_binary,
                            "local_scheduler_id", local_scheduler_id)

    def cleanup_task_table(self):
        """Clean up global state for failed local schedulers.

        This marks any tasks that were scheduled on dead local schedulers as
        TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
        self.dead_local_schedulers. For actor tasks, the task's last return
        object ID (its dummy object) is also removed from the object-table
        location entries of the plasma managers listed for it.

        Only tasks whose task-table update succeeded are counted in the
        summary log message.
        """
        tasks = self.state.task_table()
        num_tasks_updated = 0
        for task_id, task in tasks.items():
            # See if the corresponding local scheduler is alive.
            if task["LocalSchedulerID"] not in self.dead_local_schedulers:
                continue

            # Remove dummy objects returned by actor tasks from any plasma
            # manager. Although the objects may still exist in that object
            # store, this deletion makes them effectively unreachable by any
            # local scheduler connected to a different store.
            # TODO(swang): Actually remove the objects from the object store,
            # so that the reconstructed actor can reuse the same object store.
            if hex_to_binary(task["TaskSpec"]["ActorID"]) != NIL_ACTOR_ID:
                dummy_object_id = task["TaskSpec"]["ReturnObjectIDs"][-1]
                obj = self.state.object_table(dummy_object_id)
                manager_ids = obj["ManagerIDs"]
                if manager_ids is not None:
                    # The dummy object should exist on at most one plasma
                    # manager, the manager associated with the local scheduler
                    # that died.
                    assert len(manager_ids) <= 1
                    # Remove the dummy object from the plasma manager
                    # associated with the dead local scheduler, if any.
                    for manager in manager_ids:
                        ok = self.state._execute_command(
                            dummy_object_id, "RAY.OBJECT_TABLE_REMOVE",
                            dummy_object_id.id(), hex_to_binary(manager))
                        if ok != b"OK":
                            log.warning("Failed to remove object location for "
                                        "dead plasma manager.")

            # If the task is scheduled on a dead local scheduler, mark the
            # task as lost.
            key = binary_to_object_id(hex_to_binary(task_id))
            ok = self.state._execute_command(
                key, "RAY.TASK_TABLE_UPDATE",
                hex_to_binary(task_id),
                ray.experimental.state.TASK_STATUS_LOST, NIL_ID,
                task["ExecutionDependenciesString"],
                task["SpillbackCount"])
            if ok != b"OK":
                log.warning("Failed to update lost task for dead scheduler.")
            else:
                # Count only tasks that were actually marked as lost.
                num_tasks_updated += 1

        if num_tasks_updated > 0:
            log.warning("Marked {} tasks as lost.".format(num_tasks_updated))

    def cleanup_object_table(self):
        """Clean up global state for failed plasma managers.

        This removes dead plasma managers from any location entries in the
        object table. A plasma manager is deemed dead if it is in
        self.dead_plasma_managers. Only location entries whose removal
        succeeded are counted in the summary log message.
        """
        # TODO(swang): Also kill the associated plasma store, since it's no
        # longer reachable without a plasma manager.
        objects = self.state.object_table()
        num_objects_removed = 0
        for object_id, obj in objects.items():
            manager_ids = obj["ManagerIDs"]
            if manager_ids is None:
                continue
            for manager in manager_ids:
                if manager not in self.dead_plasma_managers:
                    continue
                # The object was on a dead plasma manager: remove that
                # location entry.
                ok = self.state._execute_command(object_id,
                                                 "RAY.OBJECT_TABLE_REMOVE",
                                                 object_id.id(),
                                                 hex_to_binary(manager))
                if ok != b"OK":
                    log.warning("Failed to remove object location for dead "
                                "plasma manager.")
                else:
                    num_objects_removed += 1
        if num_objects_removed > 0:
            log.warning(
                "Marked {} objects as lost.".format(num_objects_removed))

    def scan_db_client_table(self):
        """Record database clients that the client table already marks deleted.

        This must run after subscribing to the client table and before any
        subscription message is read. Otherwise deletions that happened
        before the subscription would be missed. Deleted local schedulers go
        into `self.dead_local_schedulers`, and deleted plasma managers go into
        `self.dead_plasma_managers`. Other client types are ignored.
        """
        for node_clients in self.state.client_table().values():
            for client in node_clients:
                db_client_id = client["DBClientID"]
                client_type = client["ClientType"]
                if not client["Deleted"]:
                    continue
                if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
                    self.dead_local_schedulers.add(db_client_id)
                elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
                    self.dead_plasma_managers.add(db_client_id)

    def subscribe_handler(self, channel, data):
        """Mark `channel` as subscribed after Redis confirms the subscription.

        Args:
            channel: The channel whose subscription was confirmed.
            data: Payload of the confirmation message; it is only logged.
        """
        log.debug("Subscribed to {}, data was {}".format(channel, data))
        acknowledged = True
        self.subscribed[channel] = acknowledged

    def db_client_notification_handler(self, unused_channel, data):
        """Record a client deletion announced by the db_client table.

        The payload is a SubscribeToDBClientTableReply flatbuffer. Insertions
        are ignored. A deleted local scheduler is added to
        `self.dead_local_schedulers`. A deleted plasma manager is added to
        `self.dead_plasma_managers` and its heartbeat counter is dropped.
        Cleaning up the state tables is left to the caller.

        Args:
            unused_channel: The pubsub channel the message arrived on
                (ignored).
            data: The serialized SubscribeToDBClientTableReply.
        """
        reply = (SubscribeToDBClientTableReply.
                 GetRootAsSubscribeToDBClientTableReply(data, 0))
        db_client_id = binary_to_hex(reply.DbClientId())
        client_type = reply.ClientType()
        if reply.IsInsertion():
            return

        log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
        if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
            # Adding to a set is idempotent.
            self.dead_local_schedulers.add(db_client_id)
        elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
            self.dead_plasma_managers.add(db_client_id)
            # Stop tracking heartbeats from a manager that is already dead.
            # Counter deletion does not raise for a missing key.
            del self.live_plasma_managers[db_client_id]

    def local_scheduler_info_handler(self, unused_channel, data):
        """Handle a local scheduler heartbeat from Redis.

        Decodes the LocalSchedulerInfoMessage in `data` and looks up the IP
        address of the sending local scheduler among the non-deleted
        local-scheduler entries of the client table. The reported static and
        dynamic resources are then recorded in `self.load_metrics` under
        that IP. A warning is logged if the sender cannot be found.

        Args:
            unused_channel: The pubsub channel the message arrived on
                (unused).
            data: The serialized LocalSchedulerInfoMessage.
        """
        message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
            data, 0)
        # Read each resource vector with its own length, so vectors of
        # different lengths neither raise IndexError nor drop entries.
        static_resources = {}
        for i in range(message.StaticResourcesLength()):
            static = message.StaticResources(i)
            static_resources[static.Key().decode("utf-8")] = static.Value()
        dynamic_resources = {}
        for i in range(message.DynamicResourcesLength()):
            dyn = message.DynamicResources(i)
            dynamic_resources[dyn.Key().decode("utf-8")] = dyn.Value()

        client_id = binascii.hexlify(message.DbClientId()).decode("utf-8")
        # Query through the monitor's own GlobalState (initialized in
        # __init__), like the rest of this class. This avoids relying on the
        # process-wide `ray.global_state` being initialized.
        clients = self.state.client_table()
        local_schedulers = [
            entry for node_clients in clients.values()
            for entry in node_clients
            if (entry["ClientType"] == "local_scheduler" and not
                entry["Deleted"])
        ]
        ip = None
        for ls in local_schedulers:
            if ls["DBClientID"] == client_id:
                ip = ls["AuxAddress"].split(":")[0]
        if ip:
            self.load_metrics.update(ip, static_resources, dynamic_resources)
        else:
            log.warning("Could not find ip for client {} in {}".format(
                client_id, local_schedulers))

    def plasma_manager_heartbeat_handler(self, unused_channel, data):
        """Record a heartbeat from a plasma manager.

        The manager's missed-heartbeat count is reset to zero, which also
        starts tracking a manager that was not tracked before.

        Args:
            unused_channel: The pubsub channel the message arrived on
                (ignored).
            data: Heartbeat payload whose first DB_CLIENT_ID_SIZE characters
                are the manager's client ID.
        """
        manager_id = data[:DB_CLIENT_ID_SIZE]
        self.live_plasma_managers[manager_id] = 0

    def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
        """Collect IDs of control-state entries for a driver from a shard.

        Args:
            driver_id: The ID of the driver.
            redis_shard_index: The index of the Redis shard to query.

        Returns:
            Lists of IDs: (returned_object_ids, task_ids, put_objects). The
                first two are relevant to the driver and are safe to delete.
                The last contains all "put" objects in this redis shard; each
                element is an (object_id, corresponding task_id) pair.
        """
        # TODO(zongheng): consider adding save & restore functionalities.
        # Named to avoid shadowing the `redis` module.
        redis_client = self.state.redis_clients[redis_shard_index]
        task_table_infos = {}  # task id -> TaskInfo messages

        # Scan the task table & filter to get the list of tasks belonging to
        # this driver. Use a cursor in order not to block the redis shards.
        for key in redis_client.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
            task_spec = redis_client.hgetall(key).get(b"TaskSpec")
            if task_spec is None:
                # The key vanished between SCAN and HGETALL (or has no spec).
                continue
            task_info = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
            if driver_id != task_info.DriverId():
                # Ignore tasks that aren't from this driver.
                continue
            task_table_infos[task_info.TaskId()] = task_info

        # Get the list of objects returned by these tasks. Note these might
        # not belong to this redis shard.
        returned_object_ids = [
            task_info.Returns(i)
            for task_info in task_table_infos.values()
            for i in range(task_info.ReturnsLength())
        ]

        # Also record all the ray.put()'d objects.
        put_objects = []
        for key in redis_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
            entry = redis_client.hgetall(key)
            # Redis returns bytes (note the b"..." field names), so the flag
            # must be compared with b"0"; comparing with the str "0" is always
            # False on Python 3 and never filtered anything. Entries deleted
            # since the scan (empty dict) are skipped too.
            if entry.get(b"is_put", b"0") == b"0" or b"task" not in entry:
                continue
            # Strip the prefix by length: splitting on it would truncate an ID
            # whose bytes happen to contain the prefix.
            object_id = key[len(OBJECT_INFO_PREFIX):]
            put_objects.append((object_id, entry[b"task"]))

        return returned_object_ids, list(task_table_infos), put_objects

    def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index):
        """Delete the given task and object entries from one redis shard.

        A task table entry is deleted for every ID in `task_ids`. For each ID
        in `object_ids`, its object location (OL) and object info (OI) entries
        are deleted if they are non-empty in this shard. Deletion is best
        effort: a partial failure is logged, not raised.

        Args:
            object_ids: Binary IDs of the objects to clean up on this shard.
            task_ids: Binary IDs of the tasks to clean up on this shard.
            shard_index: Index into `self.state.redis_clients`.
        """
        # Named to avoid shadowing the `redis` module.
        redis_client = self.state.redis_clients[shard_index]
        # Clean up (in the future, save) entries for non-empty objects.
        object_ids_locs = set()
        object_ids_infos = set()
        for object_id in object_ids:
            # OL.
            if redis_client.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1):
                object_ids_locs.add(object_id)
            # OI.
            if redis_client.hgetall(OBJECT_INFO_PREFIX + object_id):
                object_ids_infos.add(object_id)

        # Form the redis keys to delete.
        keys = [TASK_TABLE_PREFIX + k for k in task_ids]
        keys.extend(OBJECT_LOCATION_PREFIX + k for k in object_ids_locs)
        keys.extend(OBJECT_INFO_PREFIX + k for k in object_ids_infos)

        if not keys:
            return
        # Remove with best effort.
        num_deleted = redis_client.delete(*keys)
        log.info(
            "Removed {} dead redis entries of the driver from redis shard {}.".
            format(num_deleted, shard_index))
        if num_deleted != len(keys):
            # Both placeholders need an argument; previously shard_index was
            # missing and format() raised IndexError on this very path.
            log.warning(
                "Failed to remove {} relevant redis entries"
                " from redis shard {}.".format(len(keys) - num_deleted,
                                               shard_index))

    def _clean_up_entries_for_driver(self, driver_id):
        """Delete a driver's task and object entries from every redis shard.

        The control-state entries removed are:
            * the object info (OI) and object location (OL) entries of every
              object returned by one of the driver's tasks, plus every
              `ray.put()` object whose creating task belongs to the driver;
            * the task table entries of every task belonging to the driver.

        Args:
            driver_id: The binary ID of the driver whose entries are removed.
        """
        # TODO(zongheng): handle function_table, client_table, log_files --
        # these are in the metadata redis server, not in the shards.
        num_shards = len(self.state.redis_clients)
        object_ids = []
        task_ids = []
        put_pairs = []

        # Gather candidate IDs from every shard.
        # TODO(zongheng): consider parallelizing this loop.
        for index in range(num_shards):
            shard_objects, shard_tasks, shard_puts = (
                self._entries_for_driver_in_shard(driver_id, index))
            object_ids.extend(shard_objects)
            task_ids.extend(shard_tasks)
            put_pairs.extend(shard_puts)

        # Keep only the put objects created by one of this driver's tasks.
        owned_tasks = set(task_ids)
        object_ids.extend(
            put_id for put_id, creator in put_pairs if creator in owned_tasks)

        def shard_of(binary_id):
            # Entries are stored on the shard selected by the ID's hash.
            return binary_to_object_id(binary_id).redis_shard_hash() % (
                num_shards)

        # Bucket every ID by the shard that stores its entries.
        objects_by_shard = defaultdict(list)
        for object_id in object_ids:
            objects_by_shard[shard_of(object_id)].append(object_id)
        tasks_by_shard = defaultdict(list)
        for task_id in task_ids:
            tasks_by_shard[shard_of(task_id)].append(task_id)

        # TODO(zongheng): consider parallelizing this loop.
        for index in range(num_shards):
            self._clean_up_entries_from_shard(objects_by_shard[index],
                                              tasks_by_shard[index], index)

    def driver_removed_handler(self, unused_channel, data):
        """Handle a notification that a driver has been removed.

        This deletes the driver's task/object entries from the redis shards.
        It also releases any GPU resources that were reserved for that
        driver in Redis, by removing the driver's key from each GPU local
        scheduler's `gpus_in_use` JSON field in a WATCH/MULTI transaction.

        Args:
            unused_channel: The pubsub channel the message arrived on
                (unused).
            data: The serialized DriverTableMessage.
        """
        message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
        driver_id = message.DriverId()
        driver_id_hex = binary_to_hex(driver_id)
        log.info("Driver {} has been removed.".format(driver_id_hex))

        # Get a list of the local schedulers that have not been deleted. Use
        # the monitor's own GlobalState, like the rest of this class, rather
        # than the process-wide `ray.global_state`.
        local_schedulers = self.state.local_schedulers()

        self._clean_up_entries_for_driver(driver_id)

        # Release any GPU resources that have been reserved for this driver in
        # Redis.
        for local_scheduler in local_schedulers:
            if local_scheduler.get("GPU", 0) <= 0:
                continue
            local_scheduler_id = local_scheduler["DBClientID"]
            num_gpus_returned = 0

            # Perform a transaction to return the GPUs.
            with self.redis.pipeline() as pipe:
                while True:
                    try:
                        # Reset on every attempt, so a value popped during an
                        # aborted (WatchError) attempt is never reported.
                        num_gpus_returned = 0
                        # If this key is changed before the transaction below
                        # (the multi/exec block), then the transaction will
                        # not take place.
                        pipe.watch(local_scheduler_id)

                        result = pipe.hget(local_scheduler_id, "gpus_in_use")
                        gpus_in_use = (dict() if result is None else
                                       json.loads(result.decode("ascii")))

                        if driver_id_hex in gpus_in_use:
                            num_gpus_returned = gpus_in_use.pop(driver_id_hex)

                        pipe.multi()

                        pipe.hset(local_scheduler_id, "gpus_in_use",
                                  json.dumps(gpus_in_use))

                        pipe.execute()
                        # If a WatchError is not raised, then the operations
                        # should have gone through atomically.
                        break
                    except redis.WatchError:
                        # Another client must have changed the watched key
                        # between the time we started WATCHing it and the
                        # pipeline's execution. We should just retry.
                        continue

            log.info("Driver {} is returning GPU IDs {} to local "
                     "scheduler {}.".format(driver_id_hex, num_gpus_returned,
                                            local_scheduler_id))

    def process_messages(self):
        """Drain the subscription client, routing every message to a handler.

        Messages are pulled one at a time until the client reports that none
        is ready. A message on a channel not yet marked as subscribed goes to
        ``subscribe_handler``; otherwise the channel name picks the handler.

        Raises:
            Exception: If a subscribed channel has no known handler.
        """
        while True:
            message = self.subscribe_client.get_message()
            if message is None:
                return

            channel, data = message["channel"], message["data"]

            if not self.subscribed[channel]:
                # Not yet confirmed: this is the reply to our initial
                # subscription request.
                handler = self.subscribe_handler
            elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
                # Heartbeat from a plasma manager.
                handler = self.plasma_manager_heartbeat_handler
            elif channel == LOCAL_SCHEDULER_INFO_CHANNEL:
                # Heartbeat from a local scheduler.
                handler = self.local_scheduler_info_handler
            elif channel == DB_CLIENT_TABLE_NAME:
                # Notification from the db_client table.
                handler = self.db_client_notification_handler
            elif channel == DRIVER_DEATH_CHANNEL:
                # A driver was removed.
                log.info("message-handler: driver_removed_handler")
                handler = self.driver_removed_handler
            else:
                raise Exception("This code should be unreachable.")

            handler(channel, data)

    def run(self):
        """Run the monitor.

        This function loops forever, checking for messages about dead database
        clients and cleaning up state accordingly. Each round it:

        * lets the autoscaler (if any) act,
        * processes pending subscription messages,
        * cleans up task/actor/object state for clients newly marked dead,
        * disconnects plasma managers whose missed-heartbeat count reached
          ``ray._config.num_heartbeats_timeout()``,
        * sleeps for one heartbeat interval.
        """
        # Initialize the subscription channel.
        self.subscribe(DB_CLIENT_TABLE_NAME)
        self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL)
        self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
        self.subscribe(DRIVER_DEATH_CHANNEL)

        # Scan the database table for dead database clients. NOTE: This must be
        # called before reading any messages from the subscription channel.
        # This ensures that we start in a consistent state, since we may have
        # missed notifications that were sent before we connected to the
        # subscription channel.
        self.scan_db_client_table()
        # If there were any dead clients at startup, clean up the associated
        # state in the state tables.
        if self.dead_local_schedulers:
            self.cleanup_task_table()
            self.cleanup_actors()
        if self.dead_plasma_managers:
            self.cleanup_object_table()
        log.debug("{} dead local schedulers, {} plasma managers total, {} "
                  "dead plasma managers".format(
                      len(self.dead_local_schedulers),
                      (len(self.live_plasma_managers) +
                       len(self.dead_plasma_managers)),
                      len(self.dead_plasma_managers)))

        # Handle messages from the subscription channels.
        while True:
            # Process autoscaling actions
            if self.autoscaler:
                self.autoscaler.update()
            # Record how many dead local schedulers and plasma managers we had
            # at the beginning of this round.
            num_dead_local_schedulers = len(self.dead_local_schedulers)
            num_dead_plasma_managers = len(self.dead_plasma_managers)
            # Process a round of messages.
            self.process_messages()
            # If any new local schedulers or plasma managers were marked as
            # dead in this round, clean up the associated state.
            if len(self.dead_local_schedulers) > num_dead_local_schedulers:
                self.cleanup_task_table()
                self.cleanup_actors()
            if len(self.dead_plasma_managers) > num_dead_plasma_managers:
                self.cleanup_object_table()

            # Handle plasma managers that timed out during this round. Iterate
            # over a snapshot of the IDs because entries are deleted below.
            max_missed_heartbeats = ray._config.num_heartbeats_timeout()
            for plasma_manager_id in list(self.live_plasma_managers):
                if (self.live_plasma_managers[plasma_manager_id] >=
                        max_missed_heartbeats):
                    # Name the manager so the timeout can be traced.
                    log.warning("Timed out {} {}".format(
                        PLASMA_MANAGER_CLIENT_TYPE, plasma_manager_id))
                    # Remove the plasma manager from the managers whose
                    # heartbeats we're tracking.
                    del self.live_plasma_managers[plasma_manager_id]
                    # Remove the plasma manager from the db_client table. The
                    # corresponding state in the object table will be cleaned
                    # up once we receive the notification for this db_client
                    # deletion.
                    self.redis.execute_command("RAY.DISCONNECT",
                                               plasma_manager_id)

            # Increment the number of heartbeats that we've missed from each
            # plasma manager.
            for plasma_manager_id in self.live_plasma_managers:
                self.live_plasma_managers[plasma_manager_id] += 1

            # Wait for a heartbeat interval before processing the next round of
            # messages.
            time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3)
示例#56
0
    def testScaleUpBasedOnLoad(self):
        """Worker count follows reported load within min/max_workers."""
        cluster_config = SMALL_CLUSTER.copy()
        cluster_config.update(
            min_workers=2, max_workers=10, target_utilization_fraction=0.5)
        config_path = self.write_config(cluster_config)
        self.provider = MockProvider()
        load_metrics = LoadMetrics()
        autoscaler = StandardAutoscaler(
            config_path, load_metrics, max_failures=0, update_interval_s=0)

        def node_count():
            return len(self.provider.nodes({}))

        # Starts empty, then launches exactly min_workers and stays there.
        self.assertEqual(node_count(), 0)
        for _ in range(2):
            autoscaler.update()
            self.assertEqual(node_count(), 2)

        # Scales up as nodes are reported as used
        for ip in ("172.0.0.0", "172.0.0.1"):
            load_metrics.update(ip, {"CPU": 2}, {"CPU": 0})
        autoscaler.update()
        self.assertEqual(node_count(), 4)
        load_metrics.update("172.0.0.2", {"CPU": 2}, {"CPU": 0})
        autoscaler.update()
        self.assertEqual(node_count(), 6)

        # Holds steady when load is removed
        for ip in ("172.0.0.0", "172.0.0.1"):
            load_metrics.update(ip, {"CPU": 2}, {"CPU": 2})
        autoscaler.update()
        self.assertEqual(node_count(), 6)

        # Scales down as nodes become unused
        for ip in ("172.0.0.0", "172.0.0.1"):
            load_metrics.last_used_time_by_ip[ip] = 0
        autoscaler.update()
        self.assertEqual(node_count(), 4)
        for ip in ("172.0.0.2", "172.0.0.3"):
            load_metrics.last_used_time_by_ip[ip] = 0
        autoscaler.update()
        self.assertEqual(node_count(), 2)
示例#57
0
class Monitor(object):
    """A monitor for Ray processes.

    The monitor is in charge of cleaning up the tables in the global state
    after processes have died. The monitor is currently not responsible for
    detecting component failures.

    Attributes:
        redis: A connection to the Redis server.
        subscribe_client: A pubsub client for the Redis server. This is used to
            receive notifications about failed components.
    """

    def __init__(self,
                 redis_address,
                 redis_port,
                 autoscaling_config,
                 redis_password=None):
        """Connect to Redis and prepare autoscaling and GCS-flush state.

        Args:
            redis_address: Host of the primary Redis server.
            redis_port: Port of the primary Redis server.
            autoscaling_config: Autoscaler config path handed to
                StandardAutoscaler; a falsy value disables autoscaling.
            redis_password: Optional password used for the primary Redis
                server and, when GCS flushing is enabled, the Redis shard.
        """
        # Initialize the Redis clients.
        self.state = ray.experimental.state.GlobalState()
        self.state._initialize_global_state(
            redis_address, redis_port, redis_password=redis_password)
        self.redis = redis.StrictRedis(
            host=redis_address, port=redis_port, db=0, password=redis_password)
        # Subscriptions go through the primary Redis server only.
        self.primary_subscribe_client = self.redis.pubsub(
            ignore_subscribe_messages=True)
        # Keep a mapping from local scheduler client ID to IP address to use
        # for updating the load metrics.
        self.local_scheduler_id_to_ip_map = {}
        self.load_metrics = LoadMetrics()
        if autoscaling_config:
            self.autoscaler = StandardAutoscaler(autoscaling_config,
                                                 self.load_metrics)
        else:
            self.autoscaler = None

        # Experimental feature: GCS flushing.
        self.issue_gcs_flushes = "RAY_USE_NEW_GCS" in os.environ
        self.gcs_flush_policy = None
        if self.issue_gcs_flushes:
            # Data is stored under the first data shard, so we issue flushes to
            # that redis server.
            shard_addresses = self.redis.lrange("RedisShards", 0, -1)
            if len(shard_addresses) > 1:
                logger.warning("TODO: if launching > 1 redis shard, flushing "
                               "needs to touch shards in parallel.")
                self.issue_gcs_flushes = False
            elif not shard_addresses:
                # Without a shard there is nothing to flush; indexing [0]
                # below would raise IndexError.
                logger.warning("No Redis shards registered; turning off GCS "
                               "flushing.")
                self.issue_gcs_flushes = False
            else:
                # Entries are b"host:port". Split on the last ':' and convert
                # to str/int so redis-py receives native argument types.
                host, port = shard_addresses[0].rsplit(b":", 1)
                self.redis_shard = redis.StrictRedis(
                    host=host.decode("ascii"),
                    port=int(port),
                    password=redis_password)
                try:
                    self.redis_shard.execute_command("HEAD.FLUSH 0")
                except redis.exceptions.ResponseError as e:
                    logger.info(
                        "Turning off flushing due to exception: {}".format(
                            str(e)))
                    self.issue_gcs_flushes = False

    def subscribe(self, channel):
        """Subscribe to the given channel on the primary Redis shard.

        Args:
            channel (str): The channel to subscribe to.

        Raises:
            Exception: An exception is raised if the subscription fails.
        """
        self.primary_subscribe_client.subscribe(channel)

    def xray_heartbeat_batch_handler(self, unused_channel, data):
        """Handle an xray heartbeat batch message from Redis.

        Every heartbeat in the batch carries a client ID plus total and
        available resource label/capacity vectors. These are forwarded to the
        load metrics under the client's IP address, looked up in
        ``local_scheduler_id_to_ip_map``.

        Args:
            unused_channel: The message channel (ignored).
            data: Serialized GcsTableEntry whose first entry is parsed as a
                HeartbeatBatchTableData.
        """
        gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
            data, 0)
        heartbeat_data = gcs_entries.Entries(0)

        message = (ray.gcs_utils.HeartbeatBatchTableData.
                   GetRootAsHeartbeatBatchTableData(heartbeat_data, 0))

        for j in range(message.BatchLength()):
            heartbeat_message = message.Batch(j)

            # The total and available vectors are separate fields. Walk each
            # with its own length instead of assuming the two always match.
            static_resources = {}
            for i in range(heartbeat_message.ResourcesTotalLabelLength()):
                label = heartbeat_message.ResourcesTotalLabel(i)
                static_resources[label] = (
                    heartbeat_message.ResourcesTotalCapacity(i))

            dynamic_resources = {}
            for i in range(heartbeat_message.ResourcesAvailableLabelLength()):
                label = heartbeat_message.ResourcesAvailableLabel(i)
                dynamic_resources[label] = (
                    heartbeat_message.ResourcesAvailableCapacity(i))

            # NOTE(review): a resource listed in the totals but missing from
            # the available vector is treated as fully in use (0 available),
            # so both dicts always share the same keys. Confirm against the
            # raylet's heartbeat encoding.
            for label in static_resources:
                dynamic_resources.setdefault(label, 0)

            # Update the load metrics for this local scheduler.
            client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId())
            ip = self.local_scheduler_id_to_ip_map.get(client_id)
            if ip:
                self.load_metrics.update(ip, static_resources,
                                         dynamic_resources)
            else:
                logger.warning(
                    "Could not find ip for client {} in {}.".format(
                        client_id, self.local_scheduler_id_to_ip_map))

    def _xray_clean_up_entries_for_driver(self, driver_id):
        """Remove this driver's object/task entries from redis.

        Deletes the raylet task-table entries of every task whose spec names
        this driver, plus the object-table entries of every object whose
        computed task ID is one of those tasks. Keys are grouped per Redis
        shard and deleted with one DEL per shard. Deletion is best effort: a
        shortfall is logged, not raised.

        Args:
            driver_id: The driver id, in binary form.
        """
        xray_task_table_prefix = (
            ray.gcs_utils.TablePrefix_RAYLET_TASK_string.encode("ascii"))
        xray_object_table_prefix = (
            ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))

        # Collect the binary IDs of every task submitted by this driver.
        driver_id_hex = binary_to_hex(driver_id)
        driver_task_id_bins = {
            hex_to_binary(task_id_hex)
            for task_id_hex, task_info in self.state.task_table().items()
            if task_info["TaskSpec"]["DriverID"] == driver_id_hex
        }

        # Objects are matched to the driver through the task ID computed from
        # each object ID.
        driver_object_id_bins = set()
        for object_id in self.state.object_table():
            task_id_bin = ray.raylet.compute_task_id(object_id).id()
            if task_id_bin in driver_task_id_bins:
                driver_object_id_bins.add(object_id.id())

        num_shards = len(self.state.redis_clients)

        def to_shard_index(id_bin):
            return binary_to_object_id(id_bin).redis_shard_hash() % num_shards

        # Form the redis keys to delete.
        sharded_keys = [[] for _ in range(num_shards)]
        for task_id_bin in driver_task_id_bins:
            sharded_keys[to_shard_index(task_id_bin)].append(
                xray_task_table_prefix + task_id_bin)
        for object_id_bin in driver_object_id_bins:
            sharded_keys[to_shard_index(object_id_bin)].append(
                xray_object_table_prefix + object_id_bin)

        # Remove with best effort.
        for shard_index, keys in enumerate(sharded_keys):
            if not keys:
                continue
            # Named redis_client so the module-level `redis` is not shadowed.
            redis_client = self.state.redis_clients[shard_index]
            num_deleted = redis_client.delete(*keys)
            logger.info("Removed {} dead redis entries of the driver from"
                        " redis shard {}.".format(num_deleted, shard_index))
            if num_deleted != len(keys):
                logger.warning("Failed to remove {} relevant redis entries"
                               " from redis shard {}.".format(
                                   len(keys) - num_deleted, shard_index))

    def xray_driver_removed_handler(self, unused_channel, data):
        """Purge the state of a driver announced as removed.

        Args:
            unused_channel: Channel the notification arrived on (ignored).
            data: Serialized GcsTableEntry; its first entry is parsed as
                DriverTableData, whose DriverId names the removed driver.
        """
        gcs_utils = ray.gcs_utils
        table_entry = gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(data, 0)
        driver_record = gcs_utils.DriverTableData.GetRootAsDriverTableData(
            table_entry.Entries(0), 0)
        removed_driver_id = driver_record.DriverId()
        logger.info("XRay Driver {} has been removed.".format(
            binary_to_hex(removed_driver_id)))
        self._xray_clean_up_entries_for_driver(removed_driver_id)

    def process_messages(self, max_messages=10000):
        """Process all messages ready in the subscription channels.

        This reads messages from the subscription channels and calls the
        appropriate handlers until there are no messages left.

        Args:
            max_messages: The maximum number of messages to process before
                returning.
        """
        subscribe_clients = [self.primary_subscribe_client]
        for subscribe_client in subscribe_clients:
            for _ in range(max_messages):
                message = subscribe_client.get_message()
                if message is None:
                    # Continue on to the next subscribe client.
                    break

                # Parse the message.
                channel = message["channel"]
                data = message["data"]

                # Determine the appropriate message handler.
                message_handler = None
                if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL:
                    # Similar functionality as local scheduler info channel
                    message_handler = self.xray_heartbeat_batch_handler
                elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
                    # Handles driver death.
                    message_handler = self.xray_driver_removed_handler
                else:
                    raise Exception("This code should be unreachable.")

                # Call the handler.
                assert (message_handler is not None)
                message_handler(channel, data)

    def update_local_scheduler_map(self):
        local_schedulers = self.state.client_table()
        self.local_scheduler_id_to_ip_map = {}
        for local_scheduler_info in local_schedulers:
            client_id = local_scheduler_info.get("DBClientID") or \
                local_scheduler_info["ClientID"]
            ip_address = (
                local_scheduler_info.get("AuxAddress")
                or local_scheduler_info["NodeManagerAddress"]).split(":")[0]
            self.local_scheduler_id_to_ip_map[client_id] = ip_address

    def _maybe_flush_gcs(self):
        """Experimental: issue a flush request to the GCS.

        The purpose of this feature is to control GCS memory usage.

        To activate this feature, Ray must be compiled with the flag
        RAY_USE_NEW_GCS set, and Ray must be started at run time with the flag
        as well.
        """
        if not self.issue_gcs_flushes:
            return
        if self.gcs_flush_policy is None:
            serialized = self.redis.get("gcs_flushing_policy")
            if serialized is None:
                # Client has not set any policy; by default flushing is off.
                return
            self.gcs_flush_policy = pickle.loads(serialized)

        if not self.gcs_flush_policy.should_flush(self.redis_shard):
            return

        max_entries_to_flush = self.gcs_flush_policy.num_entries_to_flush()
        num_flushed = self.redis_shard.execute_command(
            "HEAD.FLUSH {}".format(max_entries_to_flush))
        logger.info("num_flushed {}".format(num_flushed))

        # This flushes event log and log files.
        ray.experimental.flush_redis_unsafe(self.redis)

        self.gcs_flush_policy.record_flush()

    def run(self):
        """Run the monitor's main loop; this never returns.

        After subscribing to the heartbeat-batch and driver channels, every
        round refreshes the client-ID -> IP map, lets the autoscaler (when
        configured) act, possibly flushes the GCS, drains pending messages,
        and then sleeps for one heartbeat interval.
        """
        for channel in (ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL,
                        ray.gcs_utils.XRAY_DRIVER_CHANNEL):
            self.subscribe(channel)

        # TODO(rkn): If there were any dead clients at startup, we should clean
        # up the associated state in the state tables.

        while True:
            # The client map is only consumed by the autoscaler's load
            # metrics, so refresh it before the autoscaler runs.
            self.update_local_scheduler_map()

            if self.autoscaler:
                self.autoscaler.update()

            self._maybe_flush_gcs()
            self.process_messages()

            # The heartbeat interval is configured in milliseconds.
            time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3)