def test_encoded_size_is_within_expected_limit(self) -> None:
        """Check that the pickled, base64-encoded state stays within size budgets.

        For each row of ``expected_max_sizes`` the loop inserts ``num_nodes``
        running participants plus ``num_nodes`` waiting nodes (each with a
        heartbeat) and then asserts that ``b64encode(pickle.dumps(state))`` is
        no longer than the row's budget. The row comments count both kinds of
        node, hence "10 machines" for a row of 5.
        """
        state = _RendezvousState()
        state.round = 1
        state.complete = True
        state.deadline = datetime.utcnow()
        state.closed = True

        # fmt: off
        expected_max_sizes = (
            (   5,    2 * (2 ** 10),),  #    10 machines <=   2KB  # noqa: E201, E241, E262
            (  50,   16 * (2 ** 10),),  #   100 machines <=  16KB  # noqa: E201, E241, E262
            ( 500,  160 * (2 ** 10),),  #  1000 machines <= 160KB  # noqa: E201, E241, E262
            (5000, 1600 * (2 ** 10),),  # 10000 machines <= 1.6MB  # noqa: E201, E241, E262
        )
        # fmt: on

        # NOTE: `state` is created once and shared by every subtest, so each
        # larger case inserts its nodes on top of what earlier cases added.
        # Earlier indices are rebuilt with identical arguments, so (assuming
        # _NodeDesc compares by value) they overwrite rather than duplicate.
        for num_nodes, max_byte_size in expected_max_sizes:
            with self.subTest(num_nodes=num_nodes, max_byte_size=max_byte_size):
                for idx in range(num_nodes):
                    running = _NodeDesc(f"dummy{idx}.dummy1-dummy1-dummy1-dummy1.com", 12345, idx)
                    waiting = _NodeDesc(f"dummy{idx}.dummy2-dummy2-dummy2-dummy2.com", 67890, idx)

                    state.participants[running] = idx
                    state.wait_list.add(waiting)

                    # Each heartbeat gets its own datetime object (two separate calls).
                    state.last_heartbeats[running] = datetime.utcnow()
                    state.last_heartbeats[waiting] = datetime.utcnow()

                encoded_size = len(b64encode(pickle.dumps(state)))

                self.assertLessEqual(encoded_size, max_byte_size)
    def setUp(self) -> None:
        # Local node identity plus the min/max node counts for this fixture.
        self._node = _NodeDesc("this_node", 1, 1)
        self._min_nodes, self._max_nodes = 1, 2
        self._keep_alive_interval = timedelta(seconds=30)

        # Starting state: one participant ("dummy1", rank 1) other than this node.
        self._state = _RendezvousState()
        self._state.participants[_NodeDesc("dummy1", 1, 1)] = 1

        # Fixed values returned by the patched clocks below.
        self._now = datetime(2000, 1, 1, hour=0, minute=0)
        self._deadline = 10

        # Freeze `datetime.utcnow()` inside dynamic_rendezvous to `self._now`.
        # NOTE(review): patches are started but not stopped here; presumably
        # tearDown stops them — confirm.
        self._datetime_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime")
        datetime_mock = self._datetime_patch.start()
        datetime_mock.utcnow.return_value = self._now

        # Freeze `time.monotonic()` inside dynamic_rendezvous to `self._deadline`.
        self._time_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.time")
        time_mock = self._time_patch.start()
        time_mock.monotonic.return_value = self._deadline
    def test_sync_sanitizes_state(self) -> None:
        """Verify that ``sync()`` strips the injected dead nodes from the state.

        Five extra nodes with heartbeats 91-130 seconds older than ``self._now``
        are added to a copy of the reference state: the first three as
        participants and the last two on the wait list. After syncing from the
        backend, the holder's state must equal the untouched reference state,
        so all five must have been removed.
        """
        expected_state = self._create_state()

        state = copy.deepcopy(expected_state)

        heartbeat_ages = (
            ("dead1", 91),
            ("dead2", 100),
            ("dead3", 110),
            ("dead4", 120),
            ("dead5", 130),
        )

        dead_nodes = []
        for addr, age_seconds in heartbeat_ages:
            node = _NodeDesc(addr, 1, 1)
            state.last_heartbeats[node] = self._now - timedelta(seconds=age_seconds)
            dead_nodes.append(node)

        # First three pose as participants, last two as waiting nodes.
        for node in dead_nodes[:3]:
            state.participants[node] = 0
        for node in dead_nodes[3:]:
            state.wait_list.add(node)

        self._backend.set_state_internal(state)

        state_holder = self._create_state_holder()
        state_holder.sync()

        self.assert_state_equal(state_holder.state, expected_state)
    def test_hash(self) -> None:
        """Check that ``_NodeDesc`` hashes by value, not by identity.

        Two descriptors that differ in their second and third fields must occupy
        separate set entries. A freshly constructed descriptor with the same
        fields as an existing one must hash equally and be found in the set.
        """
        desc1 = _NodeDesc("dummy_fqdn", 2, 4)
        desc2 = _NodeDesc("dummy_fqdn", 3, 5)

        descs = {desc1, desc2}

        self.assertIn(desc1, descs)
        self.assertIn(desc2, descs)

        # Different field values must not collapse into a single set entry.
        self.assertEqual(len(descs), 2)

        # The default identity-based object hash would also pass the membership
        # checks above. Only value-based __hash__/__eq__ lets an equal but
        # distinct instance hash the same and be found in the set.
        desc1_copy = _NodeDesc("dummy_fqdn", 2, 4)

        self.assertEqual(hash(desc1_copy), hash(desc1))
        self.assertIn(desc1_copy, descs)
    def test_num_nodes_waiting_returns_expected_value(self) -> None:
        """Check that ``num_nodes_waiting()`` reports the wait-list size after one sync."""
        for name in ("dummy1", "dummy2"):
            self._state.wait_list.add(_NodeDesc(name, 1, 1))

        handler = self._create_handler()

        num_waiting = handler.num_nodes_waiting()

        self.assertEqual(num_waiting, 2)
        self._mock_sync.assert_called_once()
    def _add_participants(
        self, num_participants: int, state: _RendezvousState, ranked: bool = False
    ) -> None:
        """Register ``num_participants`` dummy nodes as participants of ``state``.

        Every node is named ``dummy<k>`` and gets a last heartbeat of ``self._now``.

        Args:
            num_participants: Number of participant nodes to insert.
            state: Rendezvous state that is mutated in place.
            ranked: When true, nodes are inserted in ascending name order and
                ``dummy<k>`` receives rank ``k``. When false, nodes are inserted
                in descending name order (``dummy<n-1>`` first) and every node
                receives rank 0.
        """
        for offset in range(num_participants):
            index = offset if ranked else num_participants - offset - 1
            node = _NodeDesc(f"dummy{index}", 1, 1)

            state.participants[node] = offset if ranked else 0
            state.last_heartbeats[node] = self._now
    def setUp(self) -> None:
        self._node = _NodeDesc("this_node", 1, 1)

        self._state_holder = FakeRendezvousStateHolder()

        # Spies that forward to the holder's real methods. They must be built
        # before the holder's attributes are replaced below; otherwise they
        # would end up wrapping themselves.
        sync_spy = MagicMock(wraps=self._state_holder.sync)
        mark_dirty_spy = MagicMock(wraps=self._state_holder.mark_dirty)

        # The same spies, attached to one parent Mock (presumably so tests can
        # assert the relative order of sync/mark calls — confirm). Note the
        # dirty-marking spy is exposed here under the name `mark`.
        self._mock_state_holder = Mock()
        self._mock_state_holder.sync = sync_spy
        self._mock_state_holder.mark = mark_dirty_spy

        # Route the holder's own methods through the spies.
        setattr(self._state_holder, "sync", sync_spy)  # noqa: B010
        setattr(self._state_holder, "mark_dirty", mark_dirty_spy)  # noqa: B010

        self._state = self._state_holder.state

        self._min_nodes = self._max_nodes = 1

        self._timeout = RendezvousTimeout()

        # Freeze `datetime.utcnow()` inside dynamic_rendezvous to `self._now`.
        # NOTE(review): the patch is not stopped here; presumably tearDown
        # stops it — confirm.
        self._now = datetime(2000, 1, 1, hour=0, minute=0)
        self._datetime_patch = patch(
            "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime")
        datetime_mock = self._datetime_patch.start()
        datetime_mock.utcnow.return_value = self._now
    def setUp(self) -> None:
        self._node = _NodeDesc("this_node", 1, 1)

        self._min_nodes = self._max_nodes = 1

        # No explicit timeouts by default (presumably overridden per test).
        self._join_timeout: Optional[timedelta] = None
        self._close_timeout: Optional[timedelta] = None
        self._heartbeat_timeout: Optional[timedelta] = None

        self._keep_alive_interval = timedelta(seconds=30)

        # Store whose `get` is a mock returning a fixed payload, so tests can
        # inspect which key a call was forwarded with.
        self._store = DummyStore()
        self._mock_store_get = MagicMock(return_value=b"dummy_value")
        setattr(self._store, "get", self._mock_store_get)  # noqa: B010

        # State holder whose `sync` is a spy over the real method. The spy must
        # wrap the original bound method before the attribute is replaced.
        self._state_holder = FakeRendezvousStateHolder()
        self._mock_sync = MagicMock(wraps=self._state_holder.sync)
        setattr(self._state_holder, "sync", self._mock_sync)  # noqa: B010

        self._state = self._state_holder.state
    def test_next_rendezvous_returns_expected_value(self) -> None:
        """Check the rank/world size returned for a third joining node and its store prefixing.

        With two participants already present and ``max_nodes`` set to 3, this
        node is expected to receive rank 2 in a world of size 3. A ``get`` on
        the returned store must reach the underlying store with the key
        prefixed as asserted below.
        """
        for name in ("dummy1", "dummy2"):
            self._state.participants[_NodeDesc(name, 1, 1)] = 0

        # Set before the handler is built (presumably read by _create_handler).
        self._max_nodes = 3

        handler = self._create_handler()

        prefix_store, assigned_rank, world = handler.next_rendezvous()

        self.assertEqual(assigned_rank, 2)
        self.assertEqual(world, 3)

        prefix_store.get("dummy_key")

        self._mock_store_get.assert_called_once_with("torch.rendezvous.dummy_run_id.0/dummy_key")
    def _create_state(self) -> _RendezvousState:
        """Return a populated, closed, completed state for round 999.

        Layout:
            * participants ``dummy1``..``dummy3`` with ranks 0..2;
            * ``dummy4`` and ``dummy5`` on the wait list;
            * heartbeats for all five, aged 0, 15, 30, 60 and 90 seconds
              relative to ``self._now``.

        Every dictionary and set gets its own ``_NodeDesc`` instances.
        """
        now = self._now

        def node(name: str) -> _NodeDesc:
            return _NodeDesc(name, 1, 1)

        state = _RendezvousState()
        state.round, state.complete, state.deadline, state.closed = 999, True, now, True

        state.participants = {
            node("dummy1"): 0,
            node("dummy2"): 1,
            node("dummy3"): 2,
        }
        state.wait_list = {node("dummy4"), node("dummy5")}
        state.last_heartbeats = {
            node("dummy1"): now,
            node("dummy2"): now - timedelta(seconds=15),
            node("dummy3"): now - timedelta(seconds=30),
            node("dummy4"): now - timedelta(seconds=60),
            node("dummy5"): now - timedelta(seconds=90),
        }

        return state
    def test_keep_alive_thread_is_started_with_next_rendezvous_and_stopped_with_finalizer(
        self,
    ) -> None:
        """Check the keep-alive timer thread lifecycle.

        The thread must not exist before ``next_rendezvous()``, must exist
        afterwards, and must be gone once the handler is deleted.
        """
        self._node = _NodeDesc("this_node", 1, 3)

        # The "_3" suffix matches the third field given to this node above
        # (presumably the timer name is derived from it — confirm).
        thread_name = "RendezvousKeepAliveTimer_3"

        # Captures only `thread_name`, never `handler`, so `del handler` below
        # still drops the last reference.
        def keep_alive_thread_exists() -> bool:
            return any(t.name == thread_name for t in threading.enumerate())

        handler = self._create_handler()

        self.assertFalse(keep_alive_thread_exists())

        handler.next_rendezvous()

        self.assertTrue(keep_alive_thread_exists())

        del handler

        self.assertFalse(keep_alive_thread_exists())
    def test_repr(self) -> None:
        """Check that ``repr`` joins the three fields with underscores."""
        expected = "dummy_fqdn_3_5"

        self.assertEqual(repr(_NodeDesc("dummy_fqdn", 3, 5)), expected)