def test_preempt_workers_ask_master(dummy: bool, auto_ack: bool) -> None:
    with parallel.Execution(2) as pex:

        # Steal the automatically-created pex.distributed contexts, then test chief/worker serially
        # so we know they're not using distributed comms.
        @pex.run
        def distributed_contexts() -> core.DistributedContext:
            return pex.distributed

        # Test steps are identical for chief and worker.
        for dist in distributed_contexts:
            if not dummy:
                state, context = make_test_preempt_context(
                    dist, core.PreemptMode.WorkersAskMaster)
            else:
                context = core.DummyPreemptContext(
                    dist, core.PreemptMode.WorkersAskMaster)
            with context:
                assert context.should_preempt() is False
                if not dummy:
                    # No ack preemption calls yet.
                    state.mock_session.post.assert_not_called()
                    # Send the preemption signal.
                    state.preempt()
                    wait_on_watcher(context)
                    assert context.should_preempt(auto_ack=auto_ack) is True
                    # Call again, to make sure we only ack once.
                    assert context.should_preempt(auto_ack=auto_ack) is True
                    if auto_ack:
                        state.mock_session.post.assert_called_once()
                    else:
                        state.mock_session.post.assert_not_called()
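
# A minimal sketch of the wait_on_watcher() helper these preemption tests assume: poll
# until the background watcher inside the PreemptContext has observed the preemption
# signal, failing loudly on timeout instead of hanging the test.  The _watcher and
# _should_preempt attributes are assumed internal names, not necessarily the real ones.
import time


def wait_on_watcher(context: core.PreemptContext, timeout: float = 5.0) -> None:
    deadline = time.time() + timeout
    while time.time() < deadline:
        if context._watcher._should_preempt:  # assumed internal flag
            return
        time.sleep(0.01)
    raise AssertionError("preemption watcher never observed the preemption signal")
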
def test_searcher_chief_only(dummy: bool) -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            if not dummy:
                searcher = make_test_searcher([5, 10, 15], pex.distributed)
            else:
                searcher = core.DummySearcherContext(dist=pex.distributed)

            with parallel.raises_when(
                pex.rank != 0, RuntimeError, match="searcher.operations.*chief"
            ):
                next(iter(searcher.operations(core.SearcherMode.ChiefOnly)))
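
# A plausible sketch of the parallel.raises_when() helper used above and in the later
# tests (a hedged reconstruction, not the real implementation): behave like
# pytest.raises() when the condition holds, and like a no-op context manager otherwise,
# so the chief and workers can share one code path while only some ranks expect the error.
import contextlib
from typing import ContextManager, Optional, Type


def raises_when(
    condition: bool, error: Type[Exception], match: Optional[str] = None
) -> ContextManager:
    if condition:
        return pytest.raises(error, match=match)
    return contextlib.nullcontext()
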
def test_completion_check() -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            searcher = make_test_searcher([5], pex.distributed)

            ops = iter(searcher.operations())
            next(ops)
            # Don't complete the op.
            with parallel.raises_when(
                pex.rank == 0, RuntimeError, match="must call op.report_completed"
            ):
                next(ops)
            # Wake up worker manually; it is hung waiting for the now-failed chief.
            if pex.rank == 0:
                pex.distributed.broadcast(10)

def test_check_started(dummy: bool) -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            if not dummy:
                state, context = make_test_preempt_context(
                    pex.distributed, core.PreemptMode.WorkersAskChief)
            else:
                context = core.DummyPreemptContext(
                    pex.distributed, core.PreemptMode.WorkersAskChief)

            with pytest.raises(
                    RuntimeError,
                    match="cannot call.*should_preempt.*before.*start"):
                context.should_preempt()
            with context:
                assert context.should_preempt() is False

def test_preempt_workers_ask_chief(dummy: bool, auto_ack: bool) -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            if not dummy:
                state, context = make_test_preempt_context(
                    pex.distributed, core.PreemptMode.WorkersAskChief)
            else:
                context = core.DummyPreemptContext(
                    pex.distributed, core.PreemptMode.WorkersAskChief)

            with context:
                if pex.rank == 0:
                    # Check preemption.
                    assert context.should_preempt() is False
                    # Make sure the worker is receiving broadcasts.
                    _ = pex.distributed.broadcast(False)
                    if not dummy:
                        # No ack preemption calls yet.
                        state.mock_session.post.assert_not_called()
                        # Send the preemption signal.
                        state.preempt()
                        wait_on_watcher(context)
                        assert context.should_preempt(
                            auto_ack=auto_ack) is True
                        # Call again, to make sure we only ack once.
                        assert context.should_preempt(
                            auto_ack=auto_ack) is True
                        if auto_ack:
                            state.mock_session.post.assert_called_once()
                        else:
                            state.mock_session.post.assert_not_called()
                else:
                    # Intercept the broadcast from the chief to make sure it's happening.
                    out = pex.distributed.broadcast(None)
                    assert out is False, out
                    # Try receiving from the chief.
                    assert context.should_preempt() is False
                    if not dummy:
                        # The chief should send a True now.
                        assert context.should_preempt() is True
                        # Only the chief acknowledges the preemption signal.
                        state.mock_session.post.assert_not_called()

def test_preempt_chief_only(dummy: bool, auto_ack: bool) -> None:
    with parallel.Execution(2) as pex:

        # Steal the automatically-created pex.distributed contexts, then test chief/worker serially
        # so we know they're not using distributed comms.
        @pex.run
        def distributed_contexts() -> core.DistributedContext:
            return pex.distributed

        # Test chief.
        if not dummy:
            state, context = make_test_preempt_context(
                distributed_contexts[0], core.PreemptMode.ChiefOnly)
        else:
            context = core.DummyPreemptContext(distributed_contexts[0],
                                               core.PreemptMode.ChiefOnly)
        with context:
            assert context.should_preempt() is False
            if not dummy:
                # No ack preemption calls yet.
                state.mock_session.post.assert_not_called()
                # Send the preemption signal.
                state.preempt()
                wait_on_watcher(context)
                assert context.should_preempt(auto_ack=auto_ack) is True
                # Call again, to make sure we only ack once.
                assert context.should_preempt(auto_ack=auto_ack) is True
                if auto_ack:
                    state.mock_session.post.assert_called_once()
                else:
                    state.mock_session.post.assert_not_called()

        # Test worker.
        if not dummy:
            state, context = make_test_preempt_context(
                distributed_contexts[1], core.PreemptMode.ChiefOnly)
        else:
            context = core.DummyPreemptContext(distributed_contexts[1],
                                               core.PreemptMode.ChiefOnly)
        with context:
            with pytest.raises(RuntimeError,
                               match="should_preempt.*called from non-chief"):
                context.should_preempt()
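
# A rough sketch of what make_test_preempt_context() is assumed to build for the
# preemption tests above: a small state object exposing a mocked session (mock_session)
# and a preempt() switch reflected by whatever endpoint the watcher polls, plus a real
# core.PreemptContext wired to that session.  The PreemptState name, the response shape,
# and the PreemptContext constructor arguments are all assumptions, not the real helper.
class PreemptState:
    def __init__(self) -> None:
        self.preempted = False
        self.mock_session = mock.MagicMock()
        # Report the current flag for any GET the watcher issues (assumed response shape).
        self.mock_session.get.side_effect = lambda *args, **kwargs: mock.MagicMock(
            json=lambda: {"preempt": self.preempted}
        )

    def preempt(self) -> None:
        self.preempted = True


def make_test_preempt_context(dist, preempt_mode):
    state = PreemptState()
    # Constructor arguments are assumptions; the real wiring lives in the test module.
    context = core.PreemptContext(state.mock_session, "allocation-id", dist, preempt_mode)
    return state, context
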
def test_searcher_workers_ask_chief(dummy: bool) -> None:
    with parallel.Execution(2) as pex:

        @pex.run
        def searchers() -> core.SearcherContext:
            if not dummy:
                searcher = make_test_searcher([5, 10, 15], pex.distributed)
            else:
                searcher = core.DummySearcherContext(dist=pex.distributed)
            epochs_trained = 0
            # Iterate through ops.
            for op in searcher.operations():
                assert pex.distributed.allgather(op.length) == [op.length] * pex.size
                while epochs_trained < op.length:
                    epochs_trained += 1
                    expect = [epochs_trained] * pex.size
                    assert pex.distributed.allgather(epochs_trained) == expect
                    with parallel.raises_when(
                        pex.rank != 0, RuntimeError, match="op.report_progress.*chief"
                    ):
                        op.report_progress(epochs_trained)
                with parallel.raises_when(
                    pex.rank != 0, RuntimeError, match="op.report_completed.*chief"
                ):
                    op.report_completed(0.0)

            return searcher

        if not dummy:
            # Expect calls from chief: 15x progress, 4x completions
            chief = searchers[0]
            post_mock: Any = chief._session.post
            assert post_mock.call_count == 19, post_mock.call_args_list

            # The workers must not make any REST API calls at all.
            worker = searchers[1]
            post_mock = worker._session.post
            post_mock.assert_not_called()

def test_checkpoint_context(dummy: bool, mode: core.DownloadMode,
                            tmp_path: pathlib.Path) -> None:
    ckpt_dir = tmp_path.joinpath("ckpt-dir")
    ckpt_dir.mkdir(exist_ok=True)
    with parallel.Execution(2) as pex:

        @pex.run
        def do_test() -> None:
            storage_manager = make_mock_storage_manager(tmp_path)
            if not dummy:
                session = mock.MagicMock()
                response = requests.Response()
                response.status_code = 200
                session._do_request.return_value = response
                tbd_mgr = mock.MagicMock()
                checkpoint_context = core.CheckpointContext(
                    pex.distributed,
                    storage_manager,
                    session=session,
                    task_id="task-id",
                    allocation_id="allocation-id",
                    tbd_mgr=tbd_mgr,
                )
            else:
                checkpoint_context = core.DummyCheckpointContext(
                    pex.distributed, storage_manager)

            # Test upload.
            with parallel.raises_when(
                    pex.distributed.rank == 1,
                    RuntimeError,
                    match="upload.*non-chief",
            ):
                checkpoint_context.upload(ckpt_dir,
                                          metadata={"steps_completed": 1})
            if pex.rank == 0:
                storage_manager.upload.assert_called_once()
                storage_manager.upload.reset_mock()
                storage_manager._list_directory.assert_called_once()
                storage_manager._list_directory.reset_mock()
                if not dummy:
                    session._do_request.assert_called_once()
                    session._do_request.reset_mock()
            else:
                storage_manager.upload.assert_not_called()
                storage_manager._list_directory.assert_not_called()
                if not dummy:
                    session._do_request.assert_not_called()
                    tbd_mgr.sync.assert_not_called()

            # Test store_path.
            with parallel.raises_when(
                    pex.distributed.rank == 1,
                    RuntimeError,
                    match=r"\.store_path.*non-chief",
            ):
                with checkpoint_context.store_path(
                        metadata={"steps_completed": 1}) as _:
                    pass
            if pex.rank == 0:
                storage_manager.store_path.assert_called_once()
                storage_manager.store_path.reset_mock()
                storage_manager._list_directory.assert_called_once()
                storage_manager._list_directory.reset_mock()
                if not dummy:
                    session._do_request.assert_called_once()
                    session._do_request.reset_mock()
            else:
                storage_manager.store_path.assert_not_called()
                storage_manager._list_directory.assert_not_called()
                if not dummy:
                    session._do_request.assert_not_called()
                    tbd_mgr.sync.assert_not_called()

            # Test download.
            unique_string = "arbitrary-string"
            if pex.distributed.rank == 0:
                checkpoint_context.download("ckpt-uuid", ckpt_dir, mode)
                if mode == core.DownloadMode.NoSharedDownload:
                    # Send broadcast after download.
                    _ = pex.distributed.broadcast_local(unique_string)
            else:
                if mode == core.DownloadMode.NoSharedDownload:
                    # Receive broadcast before download, to ensure the download is not synchronized.
                    recvd = pex.distributed.broadcast_local(unique_string)
                    assert recvd == unique_string, recvd
                checkpoint_context.download("ckpt-uuid", ckpt_dir, mode)
            storage_manager.download.assert_called_once()
            storage_manager.download.reset_mock()

            # Test restore_path.
            if pex.distributed.rank == 0:
                with checkpoint_context.restore_path("ckpt-uuid", mode) as _:
                    pass
                if mode == core.DownloadMode.NoSharedDownload:
                    _ = pex.distributed.broadcast_local(unique_string)
            else:
                if mode == core.DownloadMode.NoSharedDownload:
                    recvd = pex.distributed.broadcast_local(unique_string)
                    assert recvd == unique_string, recvd
                with checkpoint_context.restore_path("ckpt-uuid", mode) as _:
                    pass
            storage_manager.restore_path.assert_called_once()
            storage_manager.restore_path.reset_mock()
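
# A rough sketch of the make_mock_storage_manager() helper assumed by the checkpoint
# test above: a MagicMock standing in for a storage manager, pre-wired so that
# _list_directory() and store_path() return usable values.  The return shapes here are
# assumptions chosen to satisfy the calls the test makes, not the real helper.
def make_mock_storage_manager(base_path: pathlib.Path) -> mock.MagicMock:
    storage_manager = mock.MagicMock()
    # Pretend the checkpoint directory contains a single small file.
    storage_manager._list_directory.return_value = {"one-file": 1}
    # store_path() is used as a context manager yielding (storage_id, path).
    storage_manager.store_path.return_value.__enter__.return_value = (
        "storage-id", base_path)
    return storage_manager
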
def test_distributed_context(cross_size: int, local_size: int,
                             force_tcp: bool) -> None:
    size = cross_size * local_size

    # Make sure `make test` doesn't hang on macOS's low default open-fd limit.  Don't apply the
    # check on linux, where a low limit is not a common default, so CI can't silently skip this test.
    if sys.platform == "darwin" and size == 16:
        import resource

        if resource.getrlimit(resource.RLIMIT_NOFILE)[0] < 1024:
            pytest.skip(
                "increase the open fd limit with `ulimit -n 1024` or greater to run this test"
            )

    with parallel.Execution(size,
                            local_size=local_size,
                            make_distributed_context=False) as pex:

        @pex.run
        def contexts() -> core.DistributedContext:
            return core.DistributedContext(
                rank=pex.rank,
                size=pex.size,
                local_rank=pex.local_rank,
                local_size=pex.local_size,
                cross_rank=pex.cross_rank,
                cross_size=pex.cross_size,
                chief_ip="localhost",
                force_tcp=force_tcp,
            )

        # Perform a broadcast.
        results = pex.run(
            lambda: contexts[pex.rank].broadcast(pex.rank))  # type: ignore
        assert results == [0] * size, "not all threads ran broadcast correctly"

        # Perform a local broadcast.
        results = pex.run(lambda: contexts[pex.rank].broadcast_local(pex.rank))
        expect = [rank - (rank % local_size)
                  for rank in range(size)]  # type: Any

        assert results == expect, "not all threads ran broadcast_local correctly"
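
        # Worked example of the expectation above: with cross_size=2 and local_size=4
        # (one assumed parametrization), every rank receives its local chief's rank,
        # so expect == [0, 0, 0, 0, 4, 4, 4, 4].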

        # Perform a gather.
        results = pex.run(
            lambda: set(contexts[pex.rank].gather(pex.rank) or []))
        expect = [
            set(range(size)) if rank == 0 else set() for rank in range(size)
        ]
        assert results == expect, "not all threads ran gather correctly"

        # Perform a local gather.
        results = pex.run(
            lambda: set(contexts[pex.rank].gather_local(pex.rank) or []))
        expect = [
            set(range(rank, rank + local_size)) if rank %
            local_size == 0 else set() for rank in range(size)
        ]
        assert results == expect, "not all threads ran gather correctly"

        # Perform an allgather.
        results = pex.run(lambda: set(contexts[pex.rank].allgather(pex.rank)))
        expect = set(range(size))
        assert results == [expect
                           ] * size, "not all threads ran allgather correctly"

        # Perform a local allgather.
        results = pex.run(
            lambda: set(contexts[pex.rank].allgather_local(pex.rank)))
        expect = [
            set(range(cross_rank * local_size, (cross_rank + 1) * local_size))
            for cross_rank, _ in itertools.product(range(cross_size),
                                                   range(local_size))
        ]
        assert results == expect, "not all threads ran allgather_local correctly"

        # Close all contexts.
        for context in contexts:
            context.close()