async def test_adaptdl_job_register_checkpoint(ray_fix):
    """Check that ``register_checkpoint`` records the checkpoint on the job.

    Before registration the ``_checkpoint_received`` flag must be clear.
    Afterwards the flag must be set, ``_checkpoint`` must hold the value, and
    the Ray object behind ``_checkpoint_ref`` must resolve to the same value.
    """
    payload = "foo"
    job = RayAdaptDLJob(None, 0, 0)

    # Nothing has been registered yet, so the flag starts out cleared.
    assert not job._checkpoint_received.is_set()

    job.register_checkpoint(payload)

    assert job._checkpoint_received.is_set()
    assert job._checkpoint == "foo"
    assert ray.get(job._checkpoint_ref) == "foo"
async def test_fetch_metrics():
    """Smoke-test ``_fetch_metrics`` against a fully populated hints payload.

    No assertion is made on the result; the test only requires that the call
    completes without raising when ``_last_metrics`` carries both the
    ``gradParams`` and ``perfParams`` sections.
    """
    perf_keys = (
        'alpha_c', 'beta_c',
        'alpha_n', 'beta_n',
        'alpha_r', 'beta_r',
        'gamma',
    )
    hints = {
        "gradParams": {"norm": 3.0, "var": 4.0},
        # Every performance parameter is zeroed out.
        "perfParams": {key: 0 for key in perf_keys},
    }

    job = RayAdaptDLJob(None, 0, 0)
    job._last_metrics = hints
    job._fetch_metrics()
async def test_adaptdl_job_update_workers():
    """Check ``update_workers`` for a fresh allocation and an unchanged one.

    ``force_worker_checkpoint`` and ``_create_workers`` are replaced with
    stand-ins that only record that they ran:

    * First call: the checkpoint stand-in and the worker-creation stand-in
      must both have run, the latter with the requested allocation.
    * Second call: ``_workers`` already maps ranks to the same nodes, so
      neither stand-in may run again.
    """
    job = RayAdaptDLJob({"CPU": 2}, 0, 3)
    allocation = [
        "adaptdl_virtual_node_0",
        "adaptdl_virtual_node_1",
        "adaptdl_virtual_node_2",
    ]

    async def fake_force_worker_checkpoint():
        print("forcing checkpoint")
        job._checkpoint_received = True
        job._checkpoint = 3

    async def fake_create_workers(allocation):
        job._workers_created = True
        job._worker_ips = allocation

    job._workers_created = False
    job._worker_ips = None
    job.force_worker_checkpoint = fake_force_worker_checkpoint
    job._create_workers = fake_create_workers

    # Phase 1: new allocation -> checkpoint is forced and workers are created.
    await job.update_workers(allocation)
    assert job._checkpoint == 3
    assert job._workers_created
    assert job._worker_ips == allocation

    # Phase 2: reset the markers and pretend the workers already occupy the
    # requested nodes; the call must then leave the markers untouched.
    job._checkpoint = 0
    job._workers_created = False
    job._workers = dict(enumerate(allocation))
    await job.update_workers(allocation)
    assert not job._workers_created
    assert job._checkpoint == 0
async def test_adaptdl_job_checkpoint(ray_fix):
    """Check that ``force_worker_checkpoint`` leaves the job not running.

    Five Ray tasks that sleep forever stand in for the job's workers; after
    the forced checkpoint ``_running`` must be false.
    """
    @ray.remote
    def idle_worker():
        # Never returns on its own.
        while True:
            time.sleep(1)

    job = RayAdaptDLJob(None, 0, 0)
    job._worker_tasks = {rank: idle_worker.remote() for rank in range(5)}
    job._running = True

    await job.force_worker_checkpoint()

    assert not job._running
async def test_adaptdl_job_create_workers():
    """Check that ``_create_workers`` produces one task per allocated node.

    ``_handle_worker_failure`` is replaced with a no-op coroutine. After the
    call, ``_worker_tasks`` must hold three entries keyed 0..2, each with a
    truthy ``called`` attribute.
    NOTE(review): ``called`` suggests the worker tasks are mock objects
    patched in outside this function -- confirm against the module setup.
    """
    job = RayAdaptDLJob({"CPU": 2}, 0, 0)

    async def ignore_worker_failure(tasks):
        pass

    job._handle_worker_failure = ignore_worker_failure

    nodes = [
        "adaptdl_virtual_node_0",
        "adaptdl_virtual_node_1",
        "adaptdl_virtual_node_2",
    ]
    await job._create_workers(nodes)

    assert len(job._worker_tasks) == 3
    for rank in range(3):
        task = job._worker_tasks[rank]
        print(type(task))
        assert task.called
async def test_controller_handle_report():
    """Check that ``_handle_report`` stores the hints and triggers a reschedule.

    A stand-in request whose ``json()`` coroutine returns a JSON string is
    passed to the controller. After giving the reschedule listener time to
    run, the test expects that:

    * the replaced ``_reschedule_jobs`` coroutine was invoked,
    * ``job._last_metrics`` decodes to the submitted hints, and
    * ``job._last_metrics`` is not the very same object as ``hints``.
    """
    controller = Controller(100, 1)
    job = RayAdaptDLJob(None, 0, 1)
    controller._job = job
    controller.rescheduled = False

    async def mocked_reschedule():
        controller.rescheduled = True

    controller._reschedule_jobs = mocked_reschedule

    class MockedRequest:
        """Minimal request stand-in exposing only an async ``json()``."""

        def __init__(self, body):
            self._body = body

        async def json(self):
            return self._body

    hints = {"some": "hints"}
    request = MockedRequest(json.dumps(hints))

    # Keep a strong reference to the listener: the event loop only holds weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    listener = asyncio.create_task(controller._reschedule_listener())
    try:
        await controller._handle_report(request)
        # Give the background listener time to pick up the request.
        await asyncio.sleep(5)

        # Separate assertions so a failure names the condition that broke.
        assert controller.rescheduled
        assert json.loads(job._last_metrics) == hints
        assert job._last_metrics is not hints
    finally:
        # Do not leak the background listener past the end of the test.
        listener.cancel()
async def test_controller_spot_termination_handler(ray_fix):
    """Check ``_spot_termination_handler`` marks the node and reschedules.

    The handler receives a task that resolves to ``"some ip"``. With
    ``_reschedule_jobs`` and the cluster's ``mark_node_for_termination``
    replaced by recording stand-ins, the test expects that the node was
    marked with that IP and that a reschedule was triggered.
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob(None, 0, 0)
    controller._job = job
    controller.rescheduled = False

    async def mocked_reschedule():
        controller.rescheduled = True

    controller._cluster = Cluster(None, 0)
    controller._cluster.marked = None
    controller._reschedule_jobs = mocked_reschedule

    def mocked_mark_node_for_termination(ip):
        controller._cluster.marked = ip

    controller._cluster.mark_node_for_termination = \
        mocked_mark_node_for_termination

    async def task():
        return "some ip"

    # Keep a strong reference to the listener: the event loop only holds weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    listener = asyncio.create_task(controller._reschedule_listener())
    try:
        termination_task = asyncio.create_task(task())
        await controller._spot_termination_handler(termination_task)
        # Give the background listener time to process the reschedule request.
        await asyncio.sleep(4)

        assert controller.rescheduled
        assert controller._cluster.marked == "some ip"
    finally:
        # Do not leak the background listener past the end of the test.
        listener.cancel()
async def test_controller_reschedule_jobs(ray_fix):
    """Exercise three overlapping ``_reschedule_jobs`` calls.

    ``update_workers``, ``_handle_worker_failure`` and the cluster's
    ``expand_cluster`` are replaced with recording stand-ins. Three reschedule
    requests are started 0, 1 and 2 seconds apart while each worker update
    takes 3 seconds. The test then expects three worker updates in total,
    and that both the failure handler and the cluster expansion saw exactly
    the allocation ``['adaptdl_virtual_node_0']``.
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob({"CPU": 1}, 0, 0)
    controller._job = job

    # Bookkeeping attributes written by the stand-ins below.
    job.forced_checkpoint = False
    job.updated = 0
    controller.handled_workers = []

    async def record_worker_failure(tasks):
        controller.handled_workers.extend(tasks)

    async def slow_update_workers(allocation):
        await asyncio.sleep(3)
        job.updated += 1
        changed = job._workers != allocation
        if changed:
            job._workers = allocation
        # Report the allocation only when it actually changed.
        return allocation if changed else None

    async def record_expand_cluster(workers, allocation):
        controller._cluster.expanded = allocation
        return allocation

    controller._handle_worker_failure = record_worker_failure
    job.update_workers = slow_update_workers
    controller._cluster = Cluster(None, 0)
    controller._cluster.expanded = None
    controller._cluster.expand_cluster = record_expand_cluster

    async def delayed_reschedule(duration):
        await asyncio.sleep(duration)
        await controller._reschedule_jobs()

    staggered = asyncio.gather(
        *(delayed_reschedule(delay) for delay in range(3)))
    await asyncio.wait_for(staggered, 15)
    await asyncio.sleep(4)

    assert job.updated == 3
    # Default allocation
    assert controller.handled_workers == ['adaptdl_virtual_node_0']
    assert controller._cluster.expanded == ['adaptdl_virtual_node_0']
async def test_controller_register_status():
    """Check that ``Controller.register_status`` forwards status to the job.

    ``RUNNING`` must update ``_status`` without setting ``completed``;
    ``SUCCEEDED`` must update ``_status`` and set ``completed``.
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob(None, 0, 0)
    controller._job = job

    await controller.register_status(Status.RUNNING.value)
    assert job._status == Status.RUNNING
    assert not job.completed.is_set()

    await controller.register_status(Status.SUCCEEDED.value)
    assert job._status == Status.SUCCEEDED
    assert job.completed.is_set()
async def test_controller_register_checkpoint(ray_fix):
    """Check that ``Controller.register_checkpoint`` reaches the job.

    The call must return a truthy value, set the job's
    ``_checkpoint_received`` flag, store the value in ``_checkpoint`` and
    make it retrievable through the Ray reference ``_checkpoint_ref``.
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob(None, 0, 0)
    controller._job = job

    checkpoint = "foo"
    checkpoint_received = await controller.register_checkpoint(checkpoint)

    assert checkpoint_received
    # Query the flag's state explicitly: asserting on the flag object itself
    # would pass whether or not it has been set.
    assert job._checkpoint_received.is_set()
    assert job._checkpoint == "foo"
    assert ray.get(job._checkpoint_ref) == "foo"
async def test_controller_register_worker(ray_fix):
    """Check that ``register_worker`` records worker IPs on the job.

    Rank 0 is registered with ``"some-ip"``, which is pre-seeded in
    ``_spot_listener_tasks``; rank 1 is registered with the local node's IP.
    ``_spot_termination_handler`` is replaced by a stand-in that stores the
    resolved value of the task it is handed.
    NOTE(review): the expected ``"a different ip"`` is not produced anywhere
    in this function -- presumably by a patched spot listener; confirm.
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob(None, 0, 0)
    controller._job = job
    controller._spot_listener_tasks = {"some-ip": 1}
    controller.task_result = None

    async def capture_termination_result(task):
        controller.task_result = ray.get(task)

    controller._spot_termination_handler = capture_termination_result

    local_ip = ray._private.services.get_node_ip_address()

    await controller.register_worker(0, "some-ip")
    await controller.register_worker(
        1, ray._private.services.get_node_ip_address())
    # Let any background work started by the registration run.
    await asyncio.sleep(1)

    assert job._workers[0] == "some-ip"
    assert job._workers[1] == local_ip
    assert controller.task_result == "a different ip"
async def test_adaptdl_job_register_hints(ray_fix):
    """Check that ``register_hints`` stores its argument in ``_last_metrics``."""
    hints = "some hints"
    job = RayAdaptDLJob(None, 0, 0)

    job.register_hints(hints)

    assert job._last_metrics == "some hints"
async def test_adaptdl_job_register_status(ray_fix):
    """Check ``register_status`` for the failed, succeeded and running states.

    For each state a fresh job is created; ``_status`` must match the
    registered state, and ``completed`` must be set for ``FAILED`` and
    ``SUCCEEDED`` but not for ``RUNNING``.
    """
    # (status to register, whether ``completed`` must be set afterwards)
    cases = (
        (Status.FAILED, True),
        (Status.SUCCEEDED, True),
        (Status.RUNNING, False),
    )
    for expected_status, expect_completed in cases:
        job = RayAdaptDLJob(None, 0, 0)
        job.register_status(expected_status.value)
        assert job._status == expected_status
        assert bool(job.completed.is_set()) == expect_completed