예제 #1
0
    def test_run_removes_from_participants(self) -> None:
        """Check that REMOVE_FROM_PARTICIPANTS drops this node from the state.

        Runs once for an incomplete rendezvous with a last-call deadline and
        once for a complete rendezvous without a deadline. In both cases the
        expected state keeps the two other participants, the ``complete`` flag,
        the deadline and the round unchanged, and no longer contains this
        node's participant entry or heartbeat.
        """
        for complete, last_call_deadline in [(False, self._now), (True, None)]:
            self._state = _RendezvousState()

            self._add_participants(2, self._state)

            self._state.participants[self._node] = 0

            self._state.last_heartbeats[self._node] = self._now

            self._state.complete = complete
            self._state.deadline = last_call_deadline

            self._state.round = 1

            expected_state = _RendezvousState()

            self._add_participants(2, expected_state)

            expected_state.complete = complete
            expected_state.deadline = last_call_deadline

            expected_state.round = 1

            with self.subTest(complete=complete):
                try:
                    self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS,
                                        expected_state)
                finally:
                    # Reset even when the assertion fails; otherwise calls
                    # recorded by a failed iteration would still be on the
                    # mock when the next sub-test runs.
                    self._mock_state_holder.reset_mock()
예제 #2
0
    def test_run_adds_to_participants_and_completes_rendezvous_if_max_nodes_is_reached(
        self, ) -> None:
        """Check that ADD_TO_PARTICIPANTS completes the rendezvous at max_nodes.

        For 0, 1 and 2 existing participants, this node starts on the wait
        list with a pending last-call deadline. ``max_nodes`` is set to one
        more than the existing participant count, so adding this node reaches
        it. The expected state:

        * holds the existing participants (added with ``ranked=True``),
        * gives this node the next rank (equal to the existing count),
        * records a heartbeat for this node and no wait-list entry,
        * is marked complete with the deadline cleared.

        The scenario runs twice: once with ``min_nodes`` equal to
        ``max_nodes`` and once with ``min_nodes`` set to 0.
        """
        for min_max_nodes_equal in [False, True]:
            for num_participants in range(3):
                rank = num_participants

                self._state = _RendezvousState()

                self._add_participants(num_participants, self._state)

                self._state.wait_list.add(self._node)

                self._state.deadline = self._now + self._timeout.last_call

                expected_state = _RendezvousState()

                self._add_participants(num_participants,
                                       expected_state,
                                       ranked=True)

                expected_state.participants[self._node] = rank

                expected_state.last_heartbeats[self._node] = self._now

                expected_state.complete = True
                expected_state.deadline = None

                # Include both loop variables so a failure report identifies
                # the exact scenario (each num_participants runs twice).
                with self.subTest(num_participants=num_participants,
                                  min_max_nodes_equal=min_max_nodes_equal):
                    self._min_nodes = num_participants + 1 if min_max_nodes_equal else 0
                    self._max_nodes = num_participants + 1

                    try:
                        self._assert_action(_Action.ADD_TO_PARTICIPANTS,
                                            expected_state)
                    finally:
                        # Reset even on failure so calls from this iteration
                        # do not remain on the mock for the next sub-test.
                        self._mock_state_holder.reset_mock()
예제 #3
0
    def test_run_adds_to_participants_and_starts_last_call_if_min_nodes_is_reached(
            self) -> None:
        """Check that ADD_TO_PARTICIPANTS starts the last call at min_nodes.

        For 0, 1 and 2 existing participants, this node starts on the wait
        list. ``min_nodes`` is one more than the existing count and
        ``max_nodes`` is two more, so adding this node reaches ``min_nodes``
        but not ``max_nodes``. The expected state:

        * adds this node with rank 0 and a heartbeat,
        * has no wait-list entry,
        * sets the deadline to ``now + last_call`` timeout,
        * leaves ``complete`` at its default.
        """
        for num_participants in range(3):
            self._state = _RendezvousState()

            self._add_participants(num_participants, self._state)

            self._state.wait_list.add(self._node)

            expected_state = _RendezvousState()

            self._add_participants(num_participants, expected_state)

            expected_state.participants[self._node] = 0

            expected_state.last_heartbeats[self._node] = self._now

            expected_state.deadline = self._now + self._timeout.last_call

            with self.subTest(num_participants=num_participants):
                self._min_nodes = num_participants + 1
                self._max_nodes = num_participants + 2

                try:
                    self._assert_action(_Action.ADD_TO_PARTICIPANTS,
                                        expected_state)
                finally:
                    # Reset even on failure so calls from this iteration do
                    # not remain on the mock for the next sub-test.
                    self._mock_state_holder.reset_mock()
예제 #4
0
    def setUp(self) -> None:
        """Create a fresh node, seeded state and frozen clocks for each test.

        ``datetime`` and ``time`` inside the dynamic rendezvous module are
        patched. ``datetime.utcnow()`` then returns ``self._now`` and
        ``time.monotonic()`` returns ``self._deadline``. The patchers are kept
        on ``self`` (presumably so a tearDown not shown here can stop them).
        """
        self._node = _NodeDesc("this_node", 1, 1)

        self._min_nodes, self._max_nodes = 1, 2

        self._keep_alive_interval = timedelta(seconds=30)

        # Seed the state with one participant that is not this node.
        seeded_state = _RendezvousState()
        seeded_state.participants[_NodeDesc("dummy1", 1, 1)] = 1
        self._state = seeded_state

        self._now = datetime(2000, 1, 1, hour=0, minute=0)
        self._deadline = 10

        # Freeze the wall clock seen by the module under test.
        self._datetime_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime")
        fake_datetime = self._datetime_patch.start()
        fake_datetime.utcnow.return_value = self._now

        # Freeze the monotonic clock seen by the module under test.
        self._time_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time")
        fake_time = self._time_patch.start()
        fake_time.monotonic.return_value = self._deadline
예제 #5
0
    def test_encoded_size_is_within_expected_limit(self) -> None:
        """Guard the base64-encoded pickle size of a large rendezvous state.

        One state object grows across the sub-tests. For each table row it is
        filled with ``num_nodes`` running participants plus ``num_nodes``
        waiting nodes, each with a heartbeat. That is why the comments count
        twice as many machines. The encoded payload must not exceed the
        listed byte budget.
        """
        state = _RendezvousState()
        state.round = 1
        state.complete = True
        state.deadline = datetime.utcnow()
        state.closed = True

        # fmt: off
        expected_max_sizes = (
            (   5,    2 * (2 ** 10),),  #    10 machines <=   2KB  # noqa: E201, E241, E262
            (  50,   16 * (2 ** 10),),  #   100 machines <=  16KB  # noqa: E201, E241, E262
            ( 500,  160 * (2 ** 10),),  #  1000 machines <= 160KB  # noqa: E201, E241, E262
            (5000, 1600 * (2 ** 10),),  # 10000 machines <= 1.6MB  # noqa: E201, E241, E262
        )
        # fmt: on

        for num_nodes, max_byte_size in expected_max_sizes:
            with self.subTest(num_nodes=num_nodes, max_byte_size=max_byte_size):
                for idx in range(num_nodes):
                    running = _NodeDesc(f"dummy{idx}.dummy1-dummy1-dummy1-dummy1.com", 12345, idx)
                    waiting = _NodeDesc(f"dummy{idx}.dummy2-dummy2-dummy2-dummy2.com", 67890, idx)

                    state.participants[running] = idx
                    state.wait_list.add(waiting)

                    state.last_heartbeats[running] = datetime.utcnow()
                    state.last_heartbeats[waiting] = datetime.utcnow()

                encoded_size = len(b64encode(pickle.dumps(state)))

                self.assertLessEqual(encoded_size, max_byte_size)
예제 #6
0
    def test_run_removes_from_waitlist(self) -> None:
        """REMOVE_FROM_WAIT_LIST should discard this node's wait-list entry.

        This node is put on the wait list with a heartbeat. The expected
        result is a default-constructed ``_RendezvousState``, i.e. neither the
        wait-list entry nor the heartbeat should remain.
        """
        node, state = self._node, self._state

        state.wait_list.add(node)
        state.last_heartbeats[node] = self._now

        self._assert_action(_Action.REMOVE_FROM_WAIT_LIST, _RendezvousState())
예제 #7
0
    def test_run_adds_to_waitlist(self) -> None:
        """ADD_TO_WAIT_LIST should queue this node and record its heartbeat."""
        node = self._node

        waiting_state = _RendezvousState()
        waiting_state.wait_list.add(node)
        waiting_state.last_heartbeats[node] = self._now

        self._assert_action(_Action.ADD_TO_WAIT_LIST, waiting_state)
예제 #8
0
    def test_run_adds_to_participants(self) -> None:
        """ADD_TO_PARTICIPANTS should register this node with rank 0.

        ``min_nodes`` and ``max_nodes`` are both 2. The expected state only
        gains this node's participant entry and heartbeat; ``complete`` and
        ``deadline`` stay at their defaults.
        """
        node = self._node

        joined_state = _RendezvousState()
        joined_state.participants[node] = 0
        joined_state.last_heartbeats[node] = self._now

        self._min_nodes = self._max_nodes = 2

        self._assert_action(_Action.ADD_TO_PARTICIPANTS, joined_state)
예제 #9
0
    def test_run_removes_from_participants_and_moves_to_next_round_if_node_is_last_participant(
        self, ) -> None:
        """Leaving a completed round as its last participant starts a new round.

        The state has this node as a participant (with a heartbeat) in a
        complete round 1. Afterwards the expected state is round 2, not
        complete, with this node's participant entry and heartbeat gone.
        """
        node = self._node
        state = self._state

        state.participants[node] = 0
        state.last_heartbeats[node] = self._now
        state.complete = True
        state.round = 1

        next_round_state = _RendezvousState()
        next_round_state.complete = False
        next_round_state.round = 2

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, next_round_state)
예제 #10
0
    def test_run_removes_from_participants_and_clears_last_call_if_rendezvous_has_less_than_min_nodes(
        self, ) -> None:
        """Dropping below min_nodes during a last call should clear the deadline.

        The state holds two other participants plus this node, and a deadline
        is set. With ``min_nodes`` = 3, removing this node leaves too few
        participants. The expected state keeps the two others, has no deadline,
        and no longer contains this node.
        """
        state = self._state
        node = self._node

        self._add_participants(2, state)
        state.participants[node] = 0
        state.last_heartbeats[node] = self._now
        state.deadline = self._now

        remaining_state = _RendezvousState()
        self._add_participants(2, remaining_state)

        # Two remaining participants fall short of min_nodes.
        self._min_nodes, self._max_nodes = 3, 4

        self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, remaining_state)
예제 #11
0
    def _create_state(self) -> _RendezvousState:
        """Build a fully populated rendezvous state fixture.

        The fixture is round 999, complete and closed, with the deadline at
        ``self._now``. Participants ``dummy1``-``dummy3`` have ranks 0-2, and
        ``dummy4``-``dummy5`` are on the wait list. Heartbeats are 0, 15, 30,
        60 and 90 seconds old respectively.
        """
        def node(name: str) -> _NodeDesc:
            # Fresh descriptor on every call, so no instance is shared between
            # the participants, wait-list and heartbeat collections.
            return _NodeDesc(name, 1, 1)

        participant_names = ("dummy1", "dummy2", "dummy3")
        waiting_names = ("dummy4", "dummy5")
        heartbeat_times = (
            self._now,
            self._now - timedelta(seconds=15),
            self._now - timedelta(seconds=30),
            self._now - timedelta(seconds=60),
            self._now - timedelta(seconds=90),
        )

        fixture = _RendezvousState()
        fixture.round = 999
        fixture.complete = True
        fixture.deadline = self._now
        fixture.closed = True
        fixture.participants = {
            node(name): rank for rank, name in enumerate(participant_names)
        }
        fixture.wait_list = {node(name) for name in waiting_names}
        fixture.last_heartbeats = {
            node(name): seen
            for name, seen in zip(participant_names + waiting_names, heartbeat_times)
        }

        return fixture
예제 #12
0
    def test_run_marks_rendezvous_closed(self) -> None:
        """MARK_RENDEZVOUS_CLOSED should yield a default state with closed=True."""
        closed_state = _RendezvousState()
        closed_state.closed = True
        self._assert_action(_Action.MARK_RENDEZVOUS_CLOSED, closed_state)
예제 #13
0
    def test_run_keeps_alive(self) -> None:
        """KEEP_ALIVE should record a heartbeat for this node at the frozen time."""
        heartbeat_state = _RendezvousState()
        heartbeat_state.last_heartbeats[self._node] = self._now
        self._assert_action(_Action.KEEP_ALIVE, heartbeat_state)
예제 #14
0
 def assert_state_empty(self, actual: _RendezvousState) -> None:
     """Assert that every attribute of *actual* matches a default-constructed state."""
     actual_fields = vars(actual)
     default_fields = vars(_RendezvousState())
     self.assertDictEqual(actual_fields, default_fields)
예제 #15
0
 def __init__(self) -> None:
     """Start with a default rendezvous state and an unset dirty marker.

     ``_dirty`` begins as None. Presumably this means "not yet reported";
     confirm against the class's other methods.
     """
     self._dirty = None
     self._state = _RendezvousState()