コード例 #1
0
async def test_adaptdl_job_register_checkpoint(ray_fix):
    """register_checkpoint stores the checkpoint and signals its arrival.

    After the call the job should hold the raw value in ``_checkpoint``,
    have ``_checkpoint_received`` set, and expose the same value through
    ``_checkpoint_ref`` (resolvable with ``ray.get``). ``ray_fix`` is
    presumably a fixture providing a running Ray instance — confirm.
    """
    expected = "foo"
    adaptdl_job = RayAdaptDLJob(None, 0, 0)

    # Nothing has been registered yet, so the event must still be clear.
    assert not adaptdl_job._checkpoint_received.is_set()

    adaptdl_job.register_checkpoint(expected)

    assert adaptdl_job._checkpoint == expected
    assert adaptdl_job._checkpoint_received.is_set()
    assert ray.get(adaptdl_job._checkpoint_ref) == expected
コード例 #2
0
async def test_fetch_metrics():
    """Smoke test: ``_fetch_metrics`` runs on a job holding reported hints.

    The stored hints have a ``"gradParams"`` section and a ``"perfParams"``
    section whose coefficients are all zero. No result is asserted; the
    test only checks that the call does not raise.
    """
    perf_params = dict.fromkeys(
        ("alpha_c", "beta_c", "alpha_n", "beta_n", "alpha_r", "beta_r",
         "gamma"),
        0)
    job = RayAdaptDLJob(None, 0, 0)
    job._last_metrics = {
        "gradParams": {"norm": 3.0, "var": 4.0},
        "perfParams": perf_params,
    }
    # TODO(review): assert on the result / side effects of _fetch_metrics
    # once its contract is pinned down.
    job._fetch_metrics()
コード例 #3
0
async def test_adaptdl_job_update_workers():
    """update_workers checkpoints and recreates workers only on a change.

    Pass 1: the job (presumably with no workers yet) is updated to a
    three-node allocation. The fake checkpoint must run and workers must be
    created for exactly that allocation.
    Pass 2: ``_workers`` already maps each rank to the same node. Neither
    the fake checkpoint nor worker creation should run again.
    """
    nodes = [f"adaptdl_virtual_node_{i}" for i in range(3)]
    job = RayAdaptDLJob({"CPU": 2}, 0, 3)

    async def fake_force_worker_checkpoint():
        print("forcing checkpoint")
        job._checkpoint_received = True
        job._checkpoint = 3

    async def fake_create_workers(allocation):
        job._workers_created = True
        job._worker_ips = allocation

    # Swap the real collaborators for fakes that just record what happened.
    job.force_worker_checkpoint = fake_force_worker_checkpoint
    job._workers_created = False
    job._worker_ips = None
    job._create_workers = fake_create_workers

    # Pass 1: both fakes should fire.
    await job.update_workers(nodes)
    assert job._checkpoint == 3
    assert job._workers_created
    assert job._worker_ips == nodes

    # Pass 2: reset the markers and pretend these workers already exist.
    job._checkpoint = 0
    job._workers_created = False
    job._workers = dict(enumerate(nodes))
    await job.update_workers(nodes)
    assert not job._workers_created
    assert job._checkpoint == 0
コード例 #4
0
async def test_adaptdl_job_checkpoint(ray_fix):
    """force_worker_checkpoint clears the job's running flag.

    Five never-ending Ray tasks stand in for the job's workers. After the
    forced checkpoint the job must no longer be flagged as running.
    """
    job = RayAdaptDLJob(None, 0, 0)

    @ray.remote
    def worker():
        # Never returns on its own; presumably force_worker_checkpoint is
        # what tears these tasks down — confirm.
        while True:
            time.sleep(1)

    tasks = {}
    for rank in range(5):
        tasks[rank] = worker.remote()
    job._worker_tasks = tasks
    job._running = True

    await job.force_worker_checkpoint()

    assert not job._running
コード例 #5
0
async def test_adaptdl_job_create_workers():
    """_create_workers launches one worker task per allocated node.

    Worker-failure handling is stubbed out so only task creation is
    exercised. ``_worker_tasks`` is expected to be keyed by rank (0..n-1).
    Each entry is expected to expose a truthy ``called`` attribute.

    NOTE(review): ``called`` does not appear to be an attribute of a Ray
    ObjectRef. This presumably relies on Ray being mocked elsewhere —
    confirm.
    """
    allocation = [
        "adaptdl_virtual_node_0", "adaptdl_virtual_node_1",
        "adaptdl_virtual_node_2"
    ]
    job = RayAdaptDLJob({"CPU": 2}, 0, 0)

    async def mocked_handle_worker_failure(tasks):
        pass

    job._handle_worker_failure = mocked_handle_worker_failure
    await job._create_workers(allocation)

    # Expectations are derived from the allocation so they stay in sync
    # with it if the node list changes.
    assert len(job._worker_tasks) == len(allocation)
    for rank in range(len(allocation)):
        assert job._worker_tasks[rank].called
コード例 #6
0
async def test_controller_handle_report():
    """A metrics report is stored on the job and triggers a reschedule.

    ``_handle_report`` receives a request whose ``json()`` returns a JSON
    string. That string is expected to end up in ``job._last_metrics``.
    The mocked ``_reschedule_jobs`` is expected to run, presumably driven
    by the background ``_reschedule_listener``.
    """
    controller = Controller(100, 1)
    job = RayAdaptDLJob(None, 0, 1)
    controller._job = job
    controller.rescheduled = False

    async def mocked_reschedule():
        controller.rescheduled = True

    controller._reschedule_jobs = mocked_reschedule
    # Keep a strong reference: the event loop holds only weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-run.
    listener = asyncio.create_task(controller._reschedule_listener())

    class MockedRequest:
        """Minimal request-like object exposing an async ``json()``."""

        def __init__(self, body):
            self._body = body

        async def json(self):
            return self._body

    hints = {"some": "hints"}
    try:
        await controller._handle_report(MockedRequest(json.dumps(hints)))

        # Give the background listener time to react to the report.
        await asyncio.sleep(5)

        assert controller.rescheduled
        assert json.loads(job._last_metrics) == hints
        # The job must hold its own value, not the caller's dict object.
        assert job._last_metrics is not hints
    finally:
        # Stop the listener so it does not leak into later tests.
        listener.cancel()
        try:
            await listener
        except asyncio.CancelledError:
            pass
コード例 #7
0
async def test_controller_spot_termination_handler(ray_fix):
    """A spot-termination notice marks the node and triggers a reschedule.

    The handler is given an asyncio task that resolves to a node IP
    (``"some ip"``). That IP is expected to reach the cluster's
    ``mark_node_for_termination``. The mocked ``_reschedule_jobs`` is
    expected to run via the background reschedule listener.
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob(None, 0, 0)
    controller._job = job
    controller.rescheduled = False

    async def mocked_reschedule():
        controller.rescheduled = True

    controller._cluster = Cluster(None, 0)
    controller._reschedule_jobs = mocked_reschedule
    # Keep a strong reference so the background task cannot be
    # garbage-collected before it runs; it is cancelled in `finally`.
    listener = asyncio.create_task(controller._reschedule_listener())

    controller._cluster.marked = None

    def mocked_mark_node_for_termination(ip):
        controller._cluster.marked = ip

    controller._cluster.mark_node_for_termination = \
        mocked_mark_node_for_termination

    async def task():
        return "some ip"

    try:
        await controller._spot_termination_handler(
            asyncio.create_task(task()))

        # Let the listener react to the reschedule request.
        await asyncio.sleep(4)
        assert controller.rescheduled
        assert controller._cluster.marked == "some ip"
    finally:
        # Stop the listener so it does not leak into later tests.
        listener.cancel()
        try:
            await listener
        except asyncio.CancelledError:
            pass
コード例 #8
0
async def test_controller_reschedule_jobs(ray_fix):
    """Staggered reschedules apply a changed allocation only once.

    Three ``_reschedule_jobs`` calls start one second apart. The mocked
    ``update_workers`` sleeps three seconds, so the calls may overlap unless
    the controller serializes them. Every call must reach ``update_workers``.
    Only the first returns a new allocation, so worker-failure handling and
    cluster expansion should each see only the default single-node
    allocation.
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob({"CPU": 1}, 0, 0)
    controller._job = job
    job.forced_checkpoint = False
    job.updated = 0
    controller.handled_workers = []

    async def mocked_handle_worker_failure(tasks):
        controller.handled_workers += tasks

    async def mocked_update_workers(allocation):
        await asyncio.sleep(3)
        job.updated += 1
        # Only report an allocation that differs from the current workers.
        if job._workers == allocation:
            return None
        job._workers = allocation
        return allocation

    controller._handle_worker_failure = mocked_handle_worker_failure
    job.update_workers = mocked_update_workers

    controller._cluster = Cluster(None, 0)
    controller._cluster.expanded = None

    async def mocked_expand_cluster(workers, allocation):
        controller._cluster.expanded = allocation
        return allocation

    controller._cluster.expand_cluster = mocked_expand_cluster

    async def delayed_reschedule(delay):
        await asyncio.sleep(delay)
        await controller._reschedule_jobs()

    staggered = (delayed_reschedule(delay) for delay in (0, 1, 2))
    await asyncio.wait_for(asyncio.gather(*staggered), 15)

    await asyncio.sleep(4)
    assert job.updated == 3

    # Default allocation
    default_allocation = ['adaptdl_virtual_node_0']
    assert controller.handled_workers == default_allocation
    assert controller._cluster.expanded == default_allocation
コード例 #9
0
async def test_controller_register_status():
    """Controller.register_status updates the job from raw status values.

    A RUNNING report must leave the job's ``completed`` event clear. A
    SUCCEEDED report must set it.
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob(None, 0, 0)
    controller._job = job

    await controller.register_status(Status.RUNNING.value)
    assert job._status == Status.RUNNING
    assert not job.completed.is_set()

    await controller.register_status(Status.SUCCEEDED.value)
    assert job._status == Status.SUCCEEDED
    assert job.completed.is_set()
コード例 #10
0
async def test_controller_register_checkpoint(ray_fix):
    """Controller.register_checkpoint hands the checkpoint to the job.

    The call must report success. The job must then have its received
    event set, store the raw checkpoint, and expose it through
    ``_checkpoint_ref`` (resolvable with ``ray.get``).
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob(None, 0, 0)
    controller._job = job
    checkpoint = "foo"

    checkpoint_received = await controller.register_checkpoint(checkpoint)

    assert checkpoint_received
    # _checkpoint_received is an Event-like object (queried via is_set()).
    # The object itself is truthy whether or not it is set, so check the
    # flag explicitly.
    assert job._checkpoint_received.is_set()
    assert job._checkpoint == "foo"
    assert ray.get(job._checkpoint_ref) == "foo"
コード例 #11
0
async def test_controller_register_worker(ray_fix):
    """register_worker records each worker's IP under its rank.

    ``_spot_listener_tasks`` is pre-seeded with ``"some-ip"``. Presumably
    this means no new spot listener is started for that IP, while one is
    started for this node's IP. The stubbed termination handler records
    whatever that listener resolves to.
    TODO(review): confirm why the expected value is "a different ip".
    """
    controller = Controller(100, 5)
    job = RayAdaptDLJob(None, 0, 0)
    controller._job = job
    controller._spot_listener_tasks = {"some-ip": 1}
    controller.task_result = None

    async def mocked_spot_termination_handler(task):
        # Resolve the listener's Ray object and remember its value.
        controller.task_result = ray.get(task)

    controller._spot_termination_handler = mocked_spot_termination_handler

    local_ip = ray._private.services.get_node_ip_address()

    await controller.register_worker(0, "some-ip")
    await controller.register_worker(1, local_ip)

    # Give the spawned handler a chance to run.
    await asyncio.sleep(1)

    assert job._workers[0] == "some-ip"
    assert job._workers[1] == local_ip
    assert controller.task_result == "a different ip"
コード例 #12
0
async def test_adaptdl_job_register_hints(ray_fix):
    """register_hints stores the given hints unchanged in ``_last_metrics``."""
    hints = "some hints"
    job = RayAdaptDLJob(None, 0, 0)
    job.register_hints(hints)
    assert job._last_metrics == hints
コード例 #13
0
async def test_adaptdl_job_register_status(ray_fix):
    """register_status maps raw values to ``Status`` and sets ``completed``.

    FAILED and SUCCEEDED must set the ``completed`` event; RUNNING must not.
    Each case uses a fresh job so the event starts out clear.
    """
    cases = (
        (Status.FAILED, True),
        (Status.SUCCEEDED, True),
        (Status.RUNNING, False),
    )
    for expected_status, should_complete in cases:
        job = RayAdaptDLJob(None, 0, 0)
        job.register_status(expected_status.value)
        assert job._status == expected_status
        assert bool(job.completed.is_set()) is should_complete