def test_run_removes_from_participants(self) -> None:
    """Exercise REMOVE_FROM_PARTICIPANTS for an open and a completed round.

    Each scenario seeds the state with the entries produced by
    ``_add_participants(2, ...)`` plus this node, then hands
    ``_assert_action`` an expected state that carries the same ``complete``,
    ``deadline`` and ``round`` values but has no entry for this node.
    """
    scenarios = [(False, self._now), (True, None)]

    for is_complete, deadline in scenarios:
        # Starting point: helper-added participants plus this node.
        initial = self._state = _RendezvousState()
        self._add_participants(2, initial)
        initial.participants[self._node] = 0
        initial.last_heartbeats[self._node] = self._now
        initial.complete = is_complete
        initial.deadline = deadline
        initial.round = 1

        # Expected outcome: identical, minus this node's entries.
        expected = _RendezvousState()
        self._add_participants(2, expected)
        expected.complete = is_complete
        expected.deadline = deadline
        expected.round = 1

        with self.subTest(complete=is_complete):
            self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected)

            # Clear recorded calls so the next scenario starts clean.
            self._mock_state_holder.reset_mock()
def test_run_adds_to_participants_and_completes_rendezvous_if_max_nodes_is_reached(
    self,
) -> None:
    """Exercise ADD_TO_PARTICIPANTS when this node fills the last open slot.

    Runs with 0, 1 and 2 pre-existing participants, once with
    ``min_nodes == max_nodes`` and once with ``min_nodes == 0``. In every
    case the expected state passed to ``_assert_action`` is complete, has
    no deadline, and assigns this node the rank ``num_participants``.
    """
    for min_max_nodes_equal in [False, True]:
        for num_participants in range(3):
            # This node is expected to get the rank right after the
            # existing (ranked) participants.
            rank = num_participants

            # Initial state: this node waits while a last-call deadline is
            # already set.
            self._state = _RendezvousState()
            self._add_participants(num_participants, self._state)
            self._state.wait_list.add(self._node)
            self._state.deadline = self._now + self._timeout.last_call

            # Expected state: this node is a ranked participant, the
            # rendezvous is complete and the deadline is cleared.
            expected_state = _RendezvousState()
            self._add_participants(num_participants, expected_state, ranked=True)
            expected_state.participants[self._node] = rank
            expected_state.last_heartbeats[self._node] = self._now
            expected_state.complete = True
            expected_state.deadline = None

            # Include BOTH loop variables in the sub-test label; otherwise
            # the two min/max configurations report under identical labels
            # and a failure cannot be attributed to one of them.
            with self.subTest(
                min_max_nodes_equal=min_max_nodes_equal,
                num_participants=num_participants,
            ):
                self._min_nodes = num_participants + 1 if min_max_nodes_equal else 0
                self._max_nodes = num_participants + 1

                self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected_state)

                # Clear recorded calls so the next case starts clean.
                self._mock_state_holder.reset_mock()
def test_run_adds_to_participants_and_starts_last_call_if_min_nodes_is_reached(
    self,
) -> None:
    """Exercise ADD_TO_PARTICIPANTS when this node brings the count to min_nodes.

    For 0, 1 and 2 pre-existing participants, ``min_nodes`` is set to one
    more than the existing count and ``max_nodes`` to two more. The expected
    state passed to ``_assert_action`` contains this node with rank 0 and a
    deadline of ``now + timeout.last_call``.
    """
    for existing in range(3):
        # Initial state: helper-added participants, this node on the wait list.
        initial = self._state = _RendezvousState()
        self._add_participants(existing, initial)
        initial.wait_list.add(self._node)

        # Expected state: this node joined and a last-call deadline is set.
        expected = _RendezvousState()
        self._add_participants(existing, expected)
        expected.participants[self._node] = 0
        expected.last_heartbeats[self._node] = self._now
        expected.deadline = self._now + self._timeout.last_call

        with self.subTest(num_participants=existing):
            self._min_nodes = existing + 1
            self._max_nodes = existing + 2

            self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected)

            # Clear recorded calls so the next iteration starts clean.
            self._mock_state_holder.reset_mock()
def setUp(self) -> None:
    """Create the fixture: node, node limits, initial state and clock patches.

    The ``datetime`` and ``time`` names inside ``dynamic_rendezvous`` are
    replaced with mocks whose ``utcnow()`` / ``monotonic()`` return the fixed
    values stored in ``self._now`` / ``self._deadline``.
    """
    # Fixed clock readings handed out by the mocks below.
    self._now = datetime(2000, 1, 1, hour=0, minute=0)
    self._deadline = 10

    # The node under test and the limits/intervals used by this fixture.
    self._node = _NodeDesc("this_node", 1, 1)
    self._min_nodes = 1
    self._max_nodes = 2
    self._keep_alive_interval = timedelta(seconds=30)

    # Initial state with a single other participant.
    initial_state = _RendezvousState()
    initial_state.participants[_NodeDesc("dummy1", 1, 1)] = 1
    self._state = initial_state

    # NOTE(review): both patches are started here; the matching stop() calls
    # are presumably in tearDown, which is not visible here -- confirm.
    self._datetime_patch = patch(
        "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
    )
    datetime_mock = self._datetime_patch.start()
    datetime_mock.utcnow.return_value = self._now

    self._time_patch = patch(
        "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time"
    )
    time_mock = self._time_patch.start()
    time_mock.monotonic.return_value = self._deadline
def test_encoded_size_is_within_expected_limit(self) -> None:
    """Check that the pickled, base64-encoded state stays within size caps.

    For each ``(num_nodes, max_byte_size)`` pair the state is filled with
    ``num_nodes`` participants and ``num_nodes`` waiting nodes (each with a
    heartbeat), then pickled and base64-encoded; the encoded length must not
    exceed ``max_byte_size``.
    """
    kib = 2 ** 10

    state = _RendezvousState()
    state.round = 1
    state.complete = True
    state.deadline = datetime.utcnow()
    state.closed = True

    # (nodes per list, byte cap); each row covers 2x that many machines.
    size_limits = (
        (5, 2 * kib),  # 10 machines <= 2KB
        (50, 16 * kib),  # 100 machines <= 16KB
        (500, 160 * kib),  # 1000 machines <= 160KB
        (5000, 1600 * kib),  # 10000 machines <= 1.6MB
    )

    # The same state object is reused for every row.
    for node_count, byte_cap in size_limits:
        with self.subTest(num_nodes=node_count, max_byte_size=byte_cap):
            for i in range(node_count):
                running = _NodeDesc(f"dummy{i}.dummy1-dummy1-dummy1-dummy1.com", 12345, i)
                waiting = _NodeDesc(f"dummy{i}.dummy2-dummy2-dummy2-dummy2.com", 67890, i)

                state.participants[running] = i
                state.wait_list.add(waiting)

                state.last_heartbeats[running] = datetime.utcnow()
                state.last_heartbeats[waiting] = datetime.utcnow()

            encoded = b64encode(pickle.dumps(state))

            self.assertLessEqual(len(encoded), byte_cap)
def test_run_removes_from_waitlist(self) -> None:
    """Exercise REMOVE_FROM_WAIT_LIST with this node on the wait list.

    The expected state passed to ``_assert_action`` is a default
    ``_RendezvousState``.
    """
    this_node = self._node

    # Put this node on the wait list with a current heartbeat.
    self._state.wait_list.add(this_node)
    self._state.last_heartbeats[this_node] = self._now

    self._assert_action(_Action.REMOVE_FROM_WAIT_LIST, _RendezvousState())
def test_run_adds_to_waitlist(self) -> None:
    """Exercise ADD_TO_WAIT_LIST.

    The expected state passed to ``_assert_action`` has this node on the
    wait list with a heartbeat of ``self._now``.
    """
    this_node = self._node

    expected = _RendezvousState()
    expected.last_heartbeats[this_node] = self._now
    expected.wait_list.add(this_node)

    self._assert_action(_Action.ADD_TO_WAIT_LIST, expected)
def test_run_adds_to_participants(self) -> None:
    """Exercise ADD_TO_PARTICIPANTS with ``min_nodes == max_nodes == 2``.

    The expected state passed to ``_assert_action`` lists this node as a
    participant with rank 0 and a heartbeat of ``self._now``.
    """
    # Node limits in effect when the action runs.
    self._min_nodes = 2
    self._max_nodes = 2

    this_node = self._node

    expected = _RendezvousState()
    expected.participants[this_node] = 0
    expected.last_heartbeats[this_node] = self._now

    self._assert_action(_Action.ADD_TO_PARTICIPANTS, expected)
def test_run_removes_from_participants_and_moves_to_next_round_if_node_is_last_participant(
    self,
) -> None:
    """Exercise REMOVE_FROM_PARTICIPANTS on a completed round-1 state.

    This node is added to the current state as a participant. The expected
    state passed to ``_assert_action`` is otherwise default, with
    ``complete`` False and ``round`` 2.
    """
    this_node = self._node
    current = self._state

    # This node participates in a completed first round.
    current.participants[this_node] = 0
    current.last_heartbeats[this_node] = self._now
    current.complete = True
    current.round = 1

    # Expected: a fresh, not-yet-complete second round.
    expected = _RendezvousState()
    expected.complete = False
    expected.round = 2

    self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected)
def test_run_removes_from_participants_and_clears_last_call_if_rendezvous_has_less_than_min_nodes(
    self,
) -> None:
    """Exercise REMOVE_FROM_PARTICIPANTS with ``min_nodes`` 3 and ``max_nodes`` 4.

    The current state holds the entries from ``_add_participants(2, ...)``
    plus this node, with a deadline set. The expected state passed to
    ``_assert_action`` holds only the helper-added entries and leaves the
    deadline at its ``_RendezvousState()`` default.
    """
    this_node = self._node
    current = self._state

    # Helper-added participants plus this node, with a deadline in place.
    self._add_participants(2, current)
    current.participants[this_node] = 0
    current.last_heartbeats[this_node] = self._now
    current.deadline = self._now

    expected = _RendezvousState()
    self._add_participants(2, expected)

    # Node limits in effect when the action runs.
    self._min_nodes = 3
    self._max_nodes = 4

    self._assert_action(_Action.REMOVE_FROM_PARTICIPANTS, expected)
def _create_state(self) -> _RendezvousState:
    """Build a fully populated state for use as a fixture.

    Returns:
        A ``_RendezvousState`` with round 999, ``complete`` and ``closed``
        set, a deadline of ``self._now``, three participants (ranks 0-2),
        two waiting nodes, and heartbeats aged 0, 15, 30, 60 and 90 seconds
        relative to ``self._now``.
    """

    def make_node(name):
        # Every fixture node uses the same trailing constructor arguments.
        return _NodeDesc(name, 1, 1)

    def seconds_before_now(seconds):
        return self._now - timedelta(seconds=seconds)

    result = _RendezvousState()
    result.round = 999
    result.complete = True
    result.deadline = self._now
    result.closed = True

    result.participants = {
        make_node("dummy1"): 0,
        make_node("dummy2"): 1,
        make_node("dummy3"): 2,
    }

    result.wait_list = {
        make_node("dummy4"),
        make_node("dummy5"),
    }

    result.last_heartbeats = {
        make_node("dummy1"): self._now,
        make_node("dummy2"): seconds_before_now(15),
        make_node("dummy3"): seconds_before_now(30),
        make_node("dummy4"): seconds_before_now(60),
        make_node("dummy5"): seconds_before_now(90),
    }

    return result
def test_run_marks_rendezvous_closed(self) -> None:
    """Exercise MARK_RENDEZVOUS_CLOSED.

    The expected state passed to ``_assert_action`` is a default
    ``_RendezvousState`` with ``closed`` set to True.
    """
    closed_state = _RendezvousState()
    closed_state.closed = True

    action = _Action.MARK_RENDEZVOUS_CLOSED
    self._assert_action(action, closed_state)
def test_run_keeps_alive(self) -> None:
    """Exercise KEEP_ALIVE.

    The expected state passed to ``_assert_action`` is a default
    ``_RendezvousState`` with this node's heartbeat set to ``self._now``.
    """
    refreshed = _RendezvousState()
    refreshed.last_heartbeats[self._node] = self._now

    action = _Action.KEEP_ALIVE
    self._assert_action(action, refreshed)
def assert_state_empty(self, actual: _RendezvousState) -> None:
    """Assert that ``actual`` has the same attributes as a fresh state.

    Args:
        actual: The state to compare, attribute by attribute, against a
            newly constructed ``_RendezvousState``.
    """
    pristine = _RendezvousState()

    self.assertDictEqual(vars(actual), vars(pristine))
def __init__(self) -> None:
    """Start with a default state and an unset (``None``) dirty marker."""
    # None rather than False, so "never set" stays distinguishable from a
    # boolean value.
    self._dirty = None
    self._state = _RendezvousState()