def test_preempt_workers_ask_master(dummy: bool, auto_ack: bool) -> None:
    with parallel.Execution(2) as pex:
        # Steal the automatically-created pex.distributed contexts, then test chief/worker
        # serially so we know they're not using distributed comms.
        @pex.run
        def distributed_contexts() -> core.DistributedContext:
            return pex.distributed

        # Test steps are identical for chief and worker.
        for dist in distributed_contexts:
            if not dummy:
                state, context = make_test_preempt_context(dist, core.PreemptMode.WorkersAskMaster)
            else:
                context = core.DummyPreemptContext(dist, core.PreemptMode.WorkersAskMaster)
            with context:
                assert context.should_preempt() is False
                if not dummy:
                    # No ack preemption calls yet.
                    state.mock_session.post.assert_not_called()
                    # Send the preemption signal.
                    state.preempt()
                    wait_on_watcher(context)
                    assert context.should_preempt(auto_ack=auto_ack) is True
                    # Call again, to make sure we only ack once.
                    assert context.should_preempt(auto_ack=auto_ack) is True
                    if auto_ack:
                        state.mock_session.post.assert_called_once()
                    else:
                        state.mock_session.post.assert_not_called()

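# The preemption tests above call a wait_on_watcher() helper defined elsewhere in this suite.
# A minimal sketch of what it has to do follows, assuming the context owns a background watcher
# object with a _should_preempt flag (both `_watcher` and `_should_preempt` are hypothetical
# names, not confirmed attributes of the real preempt context). Note that it must not call
# context.should_preempt() itself: in WorkersAskChief mode that would inject an extra
# chief-to-worker broadcast and desynchronize the workers.
def wait_on_watcher_sketch(context: Any, timeout: float = 5.0) -> None:
    import time

    deadline = time.time() + timeout
    while time.time() < deadline:
        watcher = getattr(context, "_watcher", None)  # hypothetical private attribute
        if watcher is not None and getattr(watcher, "_should_preempt", False):
            return
        time.sleep(0.01)
    raise AssertionError("preemption watcher did not observe the signal in time")
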
def test_searcher_chief_only(dummy: bool) -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            if not dummy:
                searcher = make_test_searcher([5, 10, 15], pex.distributed)
            else:
                searcher = core.DummySearcherContext(dist=pex.distributed)
            with parallel.raises_when(
                pex.rank != 0, RuntimeError, match="searcher.operations.*chief"
            ):
                next(iter(searcher.operations(core.SearcherMode.ChiefOnly)))

def test_completion_check() -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            searcher = make_test_searcher([5], pex.distributed)
            ops = iter(searcher.operations())
            next(ops)
            # Don't complete the op.
            with parallel.raises_when(
                pex.rank == 0, RuntimeError, match="must call op.report_completed"
            ):
                next(ops)
            # Wake up the worker manually; it is hung waiting for the now-failed chief.
            if pex.rank == 0:
                pex.distributed.broadcast(10)

def test_check_started(dummy: bool) -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            if not dummy:
                state, context = make_test_preempt_context(
                    pex.distributed, core.PreemptMode.WorkersAskChief
                )
            else:
                context = core.DummyPreemptContext(
                    pex.distributed, core.PreemptMode.WorkersAskChief
                )
            with pytest.raises(RuntimeError, match="cannot call.*should_preempt.*before.*start"):
                context.should_preempt()
            with context:
                assert context.should_preempt() is False

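# What test_check_started enforces, from the caller's point of view: should_preempt() is only
# legal after the preempt context has been started by entering it. A minimal usage sketch of
# that contract follows (the training-loop details are illustrative, not part of this suite):
def _preemption_usage_sketch(context: Any) -> None:
    with context:  # entering the context starts the background watcher
        for _batch in range(1000):
            # ... train on one batch ...
            if context.should_preempt():
                # ... save a checkpoint, then exit cleanly ...
                break
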
def test_preempt_workers_ask_chief(dummy: bool, auto_ack: bool) -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            if not dummy:
                state, context = make_test_preempt_context(
                    pex.distributed, core.PreemptMode.WorkersAskChief
                )
            else:
                context = core.DummyPreemptContext(
                    pex.distributed, core.PreemptMode.WorkersAskChief
                )
            with context:
                if pex.rank == 0:
                    # Check preemption.
                    assert context.should_preempt() is False
                    # Make sure the worker is receiving broadcasts.
                    _ = pex.distributed.broadcast(False)
                    if not dummy:
                        # No ack preemption calls yet.
                        state.mock_session.post.assert_not_called()
                        # Send the preemption signal.
                        state.preempt()
                        wait_on_watcher(context)
                        assert context.should_preempt(auto_ack=auto_ack) is True
                        # Call again, to make sure we only ack once.
                        assert context.should_preempt(auto_ack=auto_ack) is True
                        if auto_ack:
                            state.mock_session.post.assert_called_once()
                        else:
                            state.mock_session.post.assert_not_called()
                else:
                    # Intercept the broadcast from the chief to make sure it's happening.
                    out = pex.distributed.broadcast(None)
                    assert out is False, out
                    # Try receiving from the chief.
                    assert context.should_preempt() is False
                    if not dummy:
                        # The chief should send a True now.
                        assert context.should_preempt() is True
                        # Only the chief acknowledges the preemption signal.
                        state.mock_session.post.assert_not_called()

def test_preempt_chief_only(dummy: bool, auto_ack: bool) -> None:
    with parallel.Execution(2) as pex:
        # Steal the automatically-created pex.distributed contexts, then test chief/worker
        # serially so we know they're not using distributed comms.
        @pex.run
        def distributed_contexts() -> core.DistributedContext:
            return pex.distributed

        # Test chief.
        if not dummy:
            state, context = make_test_preempt_context(
                distributed_contexts[0], core.PreemptMode.ChiefOnly
            )
        else:
            context = core.DummyPreemptContext(distributed_contexts[0], core.PreemptMode.ChiefOnly)
        with context:
            assert context.should_preempt() is False
            if not dummy:
                # No ack preemption calls yet.
                state.mock_session.post.assert_not_called()
                # Send the preemption signal.
                state.preempt()
                wait_on_watcher(context)
                assert context.should_preempt(auto_ack=auto_ack) is True
                # Call again, to make sure we only ack once.
                assert context.should_preempt(auto_ack=auto_ack) is True
                if auto_ack:
                    state.mock_session.post.assert_called_once()
                else:
                    state.mock_session.post.assert_not_called()

        # Test worker.
        if not dummy:
            state, context = make_test_preempt_context(
                distributed_contexts[1], core.PreemptMode.ChiefOnly
            )
        else:
            context = core.DummyPreemptContext(distributed_contexts[1], core.PreemptMode.ChiefOnly)
        with context:
            with pytest.raises(RuntimeError, match="should_preempt.*called from non-chief"):
                context.should_preempt()

def test_searcher_workers_ask_chief(dummy: bool) -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def searchers() -> core.SearcherContext:
            if not dummy:
                searcher = make_test_searcher([5, 10, 15], pex.distributed)
            else:
                searcher = core.DummySearcherContext(dist=pex.distributed)
            epochs_trained = 0
            # Iterate through ops.
            for op in searcher.operations():
                assert pex.distributed.allgather(op.length) == [op.length] * pex.size
                while epochs_trained < op.length:
                    epochs_trained += 1
                    expect = [epochs_trained] * pex.size
                    assert pex.distributed.allgather(epochs_trained) == expect
                    with parallel.raises_when(
                        pex.rank != 0, RuntimeError, match="op.report_progress.*chief"
                    ):
                        op.report_progress(epochs_trained)
                with parallel.raises_when(
                    pex.rank != 0, RuntimeError, match="op.report_completed.*chief"
                ):
                    op.report_completed(0.0)
            return searcher

        if not dummy:
            # Expect calls from the chief: 15x progress, 4x completions.
            chief = searchers[0]
            post_mock: Any = chief._session.post
            assert post_mock.call_count == 19, post_mock.call_args_list
            # The workers must not make any REST API calls at all.
            worker = searchers[1]
            post_mock = worker._session.post
            post_mock.assert_not_called()

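# Arithmetic behind the call_count == 19 assertion above, under the assumption (not shown in
# this file) that the chief issues one POST per report_progress, one per report_completed, and
# one final POST when the operations() generator is exhausted: the lengths [5, 10, 15] are
# cumulative, so the loop trains 15 epochs in total, giving 15 progress POSTs + 3 op-completion
# POSTs + 1 final POST = 19.
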
def test_checkpoint_context(dummy: bool, mode: core.DownloadMode, tmp_path: pathlib.Path) -> None:
    ckpt_dir = tmp_path.joinpath("ckpt-dir")
    ckpt_dir.mkdir(exist_ok=True)
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            storage_manager = make_mock_storage_manager(tmp_path)
            if not dummy:
                session = mock.MagicMock()
                response = requests.Response()
                response.status_code = 200
                session._do_request.return_value = response
                tbd_mgr = mock.MagicMock()
                checkpoint_context = core.CheckpointContext(
                    pex.distributed,
                    storage_manager,
                    session=session,
                    task_id="task-id",
                    allocation_id="allocation-id",
                    tbd_mgr=tbd_mgr,
                )
            else:
                checkpoint_context = core.DummyCheckpointContext(pex.distributed, storage_manager)

            # Test upload.
            with parallel.raises_when(
                pex.distributed.rank == 1,
                RuntimeError,
                match="upload.*non-chief",
            ):
                checkpoint_context.upload(ckpt_dir, metadata={"steps_completed": 1})
            if pex.rank == 0:
                storage_manager.upload.assert_called_once()
                storage_manager.upload.reset_mock()
                storage_manager._list_directory.assert_called_once()
                storage_manager._list_directory.reset_mock()
                if not dummy:
                    session._do_request.assert_called_once()
                    session._do_request.reset_mock()
            else:
                storage_manager.upload.assert_not_called()
                storage_manager._list_directory.assert_not_called()
                if not dummy:
                    session._do_request.assert_not_called()
                    tbd_mgr.sync.assert_not_called()

            # Test store_path.
            with parallel.raises_when(
                pex.distributed.rank == 1,
                RuntimeError,
                match=r"\.store_path.*non-chief",
            ):
                with checkpoint_context.store_path(metadata={"steps_completed": 1}) as _:
                    pass
            if pex.rank == 0:
                storage_manager.store_path.assert_called_once()
                storage_manager.store_path.reset_mock()
                storage_manager._list_directory.assert_called_once()
                storage_manager._list_directory.reset_mock()
                if not dummy:
                    session._do_request.assert_called_once()
                    session._do_request.reset_mock()
            else:
                storage_manager.store_path.assert_not_called()
                storage_manager._list_directory.assert_not_called()
                if not dummy:
                    session._do_request.assert_not_called()
                    tbd_mgr.sync.assert_not_called()

            # Test download.
            unique_string = "arbitrary-string"
            if pex.distributed.rank == 0:
                checkpoint_context.download("ckpt-uuid", ckpt_dir, mode)
                if mode == core.DownloadMode.NoSharedDownload:
                    # Send broadcast after download.
                    _ = pex.distributed.broadcast_local(unique_string)
            else:
                if mode == core.DownloadMode.NoSharedDownload:
                    # Receive broadcast before download, to ensure the download is not
                    # synchronized.
                    recvd = pex.distributed.broadcast_local(unique_string)
                    assert recvd == unique_string, recvd
                checkpoint_context.download("ckpt-uuid", ckpt_dir, mode)
            storage_manager.download.assert_called_once()
            storage_manager.download.reset_mock()

            # Test restore_path.
            if pex.distributed.rank == 0:
                with checkpoint_context.restore_path("ckpt-uuid", mode) as _:
                    pass
                if mode == core.DownloadMode.NoSharedDownload:
                    _ = pex.distributed.broadcast_local(unique_string)
            else:
                if mode == core.DownloadMode.NoSharedDownload:
                    recvd = pex.distributed.broadcast_local(unique_string)
                    assert recvd == unique_string, recvd
                with checkpoint_context.restore_path("ckpt-uuid", mode) as _:
                    pass
            storage_manager.restore_path.assert_called_once()
            storage_manager.restore_path.reset_mock()

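# The checkpoint test above depends on a make_mock_storage_manager() helper defined elsewhere.
# A minimal sketch is below -- hypothetical, not the repo's actual implementation. It only has
# to expose the attributes the test asserts on (upload, download, _list_directory, and the
# store_path/restore_path context managers), all backed by MagicMock so call counts can be
# checked and reset.
def make_mock_storage_manager_sketch(base_path: pathlib.Path) -> mock.MagicMock:
    import contextlib

    storage_manager = mock.MagicMock()
    storage_manager.upload.return_value = None
    storage_manager.download.return_value = None
    # The checkpoint context consults _list_directory when reporting an uploaded checkpoint;
    # the contents of this mapping are a placeholder.
    storage_manager._list_directory.return_value = {"ckpt.pkl": 1}

    @contextlib.contextmanager
    def _fake_path(*args: Any, **kwargs: Any) -> Any:
        yield base_path

    # store_path/restore_path behave as context managers yielding a usable directory while
    # still recording their calls.
    storage_manager.store_path = mock.MagicMock(side_effect=_fake_path)
    storage_manager.restore_path = mock.MagicMock(side_effect=_fake_path)
    return storage_manager
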
def test_distributed_context(cross_size: int, local_size: int, force_tcp: bool) -> None:
    size = cross_size * local_size

    # Make sure `make test` doesn't hang on a macbook's default values. Avoid skipping on linux,
    # because a low fd limit is not a common default there, and to avoid false positives in CI.
    if sys.platform == "darwin" and size == 16:
        import resource

        if resource.getrlimit(resource.RLIMIT_NOFILE)[0] < 1024:
            pytest.skip(
                "increase the open fd limit with `ulimit -n 1024` or greater to run this test"
            )

    with parallel.Execution(size, local_size=local_size, make_distributed_context=False) as pex:

        @pex.run
        def contexts() -> core.DistributedContext:
            return core.DistributedContext(
                rank=pex.rank,
                size=pex.size,
                local_rank=pex.local_rank,
                local_size=pex.local_size,
                cross_rank=pex.cross_rank,
                cross_size=pex.cross_size,
                chief_ip="localhost",
                force_tcp=force_tcp,
            )

        # Perform a broadcast.
        results = pex.run(lambda: contexts[pex.rank].broadcast(pex.rank))  # type: ignore
        assert results == [0] * size, "not all threads ran broadcast correctly"

        # Perform a local broadcast.
        results = pex.run(lambda: contexts[pex.rank].broadcast_local(pex.rank))
        expect = [rank - (rank % local_size) for rank in range(size)]  # type: Any
        assert results == expect, "not all threads ran broadcast_local correctly"

        # Perform a gather.
        results = pex.run(lambda: set(contexts[pex.rank].gather(pex.rank) or []))
        chief = set(range(size))
        expect = [set(range(size)) if rank == 0 else set() for rank in range(size)]
        assert results == [chief] + [set()] * (size - 1), "not all threads ran gather correctly"

        # Perform a local gather.
        results = pex.run(lambda: set(contexts[pex.rank].gather_local(pex.rank) or []))
        expect = [
            set(range(rank, rank + local_size)) if rank % local_size == 0 else set()
            for rank in range(size)
        ]
        assert results == expect, "not all threads ran gather_local correctly"

        # Perform an allgather.
        results = pex.run(lambda: set(contexts[pex.rank].allgather(pex.rank)))
        expect = set(range(size))
        assert results == [expect] * size, "not all threads ran allgather correctly"

        # Perform a local allgather.
        results = pex.run(lambda: set(contexts[pex.rank].allgather_local(pex.rank)))
        expect = [
            set(range(cross_rank * local_size, (cross_rank + 1) * local_size))
            for cross_rank, _ in itertools.product(range(cross_size), range(local_size))
        ]
        assert results == expect, "not all threads ran allgather_local correctly"

        # Close all contexts.
        for context in contexts:
            context.close()

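# The darwin-only skip in test_distributed_context exists because a 16-way run can exceed the
# default per-process open-file limit on macOS (commonly 256). If you hit it locally, the limit
# can also be raised programmatically; a sketch using the stdlib resource module (POSIX-only):
def _ensure_fd_limit_sketch(minimum: int = 1024) -> None:
    import resource

    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft < minimum:
        resource.setrlimit(resource.RLIMIT_NOFILE, (min(minimum, hard), hard))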