def test_encoded_size_is_within_expected_limit(self) -> None:
    """Check that a pickled, base64-encoded state stays within a byte budget.

    For every ``(num_nodes, max_byte_size)`` pair, ``num_nodes`` participants
    and ``num_nodes`` waiting nodes are registered, so the state describes
    ``2 * num_nodes`` machines. The same ``state`` object is reused across
    sub-tests. Node descriptors are derived from the loop index, so each round
    overwrites the previous round's entries and then extends them.
    """
    state = _RendezvousState()
    state.round = 1
    state.complete = True
    state.deadline = datetime.utcnow()
    state.closed = True

    kib = 2 ** 10
    # (nodes per role, maximum encoded size in bytes)
    budgets = (
        (5, 2 * kib),  # 10 machines <= 2KB
        (50, 16 * kib),  # 100 machines <= 16KB
        (500, 160 * kib),  # 1000 machines <= 160KB
        (5000, 1600 * kib),  # 10000 machines <= 1.6MB
    )

    for node_count, byte_limit in budgets:
        with self.subTest(num_nodes=node_count, max_byte_size=byte_limit):
            for idx in range(node_count):
                running = _NodeDesc(
                    f"dummy{idx}.dummy1-dummy1-dummy1-dummy1.com", 12345, idx
                )
                waiting = _NodeDesc(
                    f"dummy{idx}.dummy2-dummy2-dummy2-dummy2.com", 67890, idx
                )

                state.participants[running] = idx
                state.wait_list.add(waiting)

                # Separate utcnow() calls give every node its own datetime
                # object (pickle memoizes shared objects by identity).
                state.last_heartbeats[running] = datetime.utcnow()
                state.last_heartbeats[waiting] = datetime.utcnow()

            encoded = b64encode(pickle.dumps(state))
            self.assertLessEqual(len(encoded), byte_limit)
def setUp(self) -> None:
    # This node's identity and the rendezvous size limits.
    self._node = _NodeDesc("this_node", 1, 1)
    self._min_nodes = 1
    self._max_nodes = 2
    self._keep_alive_interval = timedelta(seconds=30)

    # Initial state with a single existing participant.
    self._state = _RendezvousState()
    self._state.participants[_NodeDesc("dummy1", 1, 1)] = 1

    # Fixed values returned by the patched clocks below.
    self._now = datetime(2000, 1, 1, hour=0, minute=0)
    self._deadline = 10

    # Freeze `datetime.utcnow()` as seen by dynamic_rendezvous.
    self._datetime_patch = patch(
        "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
    )
    patched_datetime = self._datetime_patch.start()
    patched_datetime.utcnow.return_value = self._now

    # Freeze `time.monotonic()` as seen by dynamic_rendezvous.
    self._time_patch = patch(
        "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time"
    )
    patched_time = self._time_patch.start()
    patched_time.monotonic.return_value = self._deadline
def test_sync_sanitizes_state(self) -> None:
    """Stale nodes injected into the backend state are gone after ``sync()``.

    Three participants and two waiting nodes are added to the state, with last
    heartbeats 91-130 seconds before ``self._now``. After syncing, the holder's
    state is expected to equal the original, unpolluted state.
    """
    expected_state = self._create_state()

    state = copy.deepcopy(expected_state)

    stale_participants = (
        (_NodeDesc("dead1", 1, 1), 91),
        (_NodeDesc("dead2", 1, 1), 100),
        (_NodeDesc("dead3", 1, 1), 110),
    )
    stale_waiters = (
        (_NodeDesc("dead4", 1, 1), 120),
        (_NodeDesc("dead5", 1, 1), 130),
    )

    for node, age in stale_participants:
        state.last_heartbeats[node] = self._now - timedelta(seconds=age)
        state.participants[node] = 0

    for node, age in stale_waiters:
        state.last_heartbeats[node] = self._now - timedelta(seconds=age)
        state.wait_list.add(node)

    self._backend.set_state_internal(state)

    holder = self._create_state_holder()
    holder.sync()

    self.assert_state_equal(holder.state, expected_state)
def test_hash(self) -> None:
    """Distinct node descriptors can be stored in and found in a set."""
    first = _NodeDesc("dummy_fqdn", 2, 4)
    second = _NodeDesc("dummy_fqdn", 3, 5)

    members = {first, second}

    for desc in (first, second):
        self.assertIn(desc, members)
def test_num_nodes_waiting_returns_expected_value(self) -> None:
    """``num_nodes_waiting`` reports the wait-list size after a single sync."""
    for name in ("dummy1", "dummy2"):
        self._state.wait_list.add(_NodeDesc(name, 1, 1))

    handler = self._create_handler()

    waiting = handler.num_nodes_waiting()
    self.assertEqual(waiting, 2)

    self._mock_sync.assert_called_once()
def _add_participants(
    self, num_participants: int, state: _RendezvousState, ranked: bool = False
) -> None:
    """Insert ``num_participants`` dummy nodes into ``state``.

    Every node gets a heartbeat of ``self._now``.

    * ``ranked=True``: nodes ``dummy0 .. dummy{n-1}`` are inserted in that
      order, and each node's rank equals its index.
    * ``ranked=False``: the same nodes are inserted in reverse order
      (``dummy{n-1}`` first), all with rank 0.
    """
    indices = range(num_participants)
    if not ranked:
        indices = reversed(indices)

    for idx in indices:
        node = _NodeDesc(f"dummy{idx}", 1, 1)
        state.participants[node] = idx if ranked else 0
        state.last_heartbeats[node] = self._now
def setUp(self) -> None:
    self._node = _NodeDesc("this_node", 1, 1)

    self._state_holder = FakeRendezvousStateHolder()

    # Spies that record calls while still delegating to the real holder
    # methods. They must be built before the methods are replaced below.
    sync_spy = MagicMock(wraps=self._state_holder.sync)
    mark_spy = MagicMock(wraps=self._state_holder.mark_dirty)

    # A parent mock that exposes the spies as `sync` and `mark`.
    self._mock_state_holder = Mock()
    self._mock_state_holder.sync = sync_spy
    self._mock_state_holder.mark = mark_spy

    # Route the holder's own methods through the same spies.
    setattr(self._state_holder, "sync", sync_spy)  # noqa: B010
    setattr(self._state_holder, "mark_dirty", mark_spy)  # noqa: B010

    self._state = self._state_holder.state

    self._min_nodes = 1
    self._max_nodes = 1

    self._timeout = RendezvousTimeout()

    # Freeze `datetime.utcnow()` as seen by dynamic_rendezvous.
    self._now = datetime(2000, 1, 1, hour=0, minute=0)

    self._datetime_patch = patch(
        "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
    )
    patched_datetime = self._datetime_patch.start()
    patched_datetime.utcnow.return_value = self._now
def setUp(self) -> None:
    self._node = _NodeDesc("this_node", 1, 1)

    self._min_nodes = 1
    self._max_nodes = 1

    # Optional timeouts start out unset.
    self._join_timeout: Optional[timedelta] = None
    self._close_timeout: Optional[timedelta] = None
    self._heartbeat_timeout: Optional[timedelta] = None

    self._keep_alive_interval = timedelta(seconds=30)

    # A dummy store whose `get` always returns b"dummy_value" and records calls.
    store = DummyStore()
    store_get = MagicMock(return_value=b"dummy_value")
    setattr(store, "get", store_get)  # noqa: B010
    self._store = store
    self._mock_store_get = store_get

    # A fake state holder whose `sync` is spied on but still runs.
    holder = FakeRendezvousStateHolder()
    sync_spy = MagicMock(wraps=holder.sync)
    setattr(holder, "sync", sync_spy)  # noqa: B010
    self._state_holder = holder
    self._mock_sync = sync_spy

    self._state = holder.state
def test_next_rendezvous_returns_expected_value(self) -> None:
    """Joining after two participants yields rank 2 in a world of size 3.

    It also checks that keys read through the returned store reach the
    underlying store with the ``torch.rendezvous.dummy_run_id.0/`` prefix.
    """
    for name in ("dummy1", "dummy2"):
        self._state.participants[_NodeDesc(name, 1, 1)] = 0

    self._max_nodes = 3

    handler = self._create_handler()

    result = handler.next_rendezvous()
    store, rank, world_size = result

    self.assertEqual(rank, 2)
    self.assertEqual(world_size, 3)

    store.get("dummy_key")

    self._mock_store_get.assert_called_once_with(
        "torch.rendezvous.dummy_run_id.0/dummy_key"
    )
def _create_state(self) -> _RendezvousState:
    """Build a completed, closed round-999 state with fixed membership.

    * Participants: ``dummy1``-``dummy3``, with ranks 0-2.
    * Waiting: ``dummy4`` and ``dummy5``.
    * Heartbeats: 0, 15, 30, 60 and 90 seconds before ``self._now``.
    """

    def node(idx: int) -> _NodeDesc:
        # A fresh descriptor on every call, as the original literals produced.
        return _NodeDesc(f"dummy{idx}", 1, 1)

    state = _RendezvousState()
    state.round = 999
    state.complete = True
    state.deadline = self._now
    state.closed = True

    state.participants = {node(i): i - 1 for i in (1, 2, 3)}
    state.wait_list = {node(4), node(5)}

    state.last_heartbeats = {node(1): self._now}
    for idx, age in ((2, 15), (3, 30), (4, 60), (5, 90)):
        state.last_heartbeats[node(idx)] = self._now - timedelta(seconds=age)

    return state
def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_finalizer(
    self,
) -> None:
    """The keep-alive thread exists only between ``next_rendezvous`` and teardown."""
    self._node = _NodeDesc("this_node", 1, 3)

    # NOTE(review): the "_3" suffix appears to match the node's last field (3).
    thread_name = "RendezvousKeepAliveTimer_3"

    def keep_alive_running() -> bool:
        # Must not reference `handler`, or `del handler` would not drop it.
        return any(t.name == thread_name for t in threading.enumerate())

    handler = self._create_handler()

    self.assertFalse(keep_alive_running())

    handler.next_rendezvous()

    self.assertTrue(keep_alive_running())

    del handler

    self.assertFalse(keep_alive_running())
def test_repr(self) -> None:
    """``repr`` of a node descriptor joins its three fields with underscores."""
    expected = "dummy_fqdn_3_5"

    actual = repr(_NodeDesc("dummy_fqdn", 3, 5))

    self.assertEqual(actual, expected)