def setUp(self) -> None:
    """Prepare a fake state holder with spied ``sync``/``mark_dirty`` and a frozen clock.

    ``sync`` and ``mark_dirty`` on the fake holder are replaced by
    ``MagicMock`` wrappers that still delegate to the real methods. The same
    wrappers are attached to ``self._mock_state_holder`` (``mark_dirty`` under
    the attribute name ``mark``), so the order of calls can be inspected
    through that single mock. ``datetime`` in the dynamic_rendezvous module is
    patched so that ``utcnow()`` returns ``self._now``.
    """
    self._node = _NodeDesc("this_node", 1, 1)

    holder = FakeRendezvousStateHolder()
    self._state_holder = holder

    sync_spy = MagicMock(wraps=holder.sync)
    mark_spy = MagicMock(wraps=holder.mark_dirty)

    recorder = Mock()
    recorder.sync = sync_spy
    recorder.mark = mark_spy
    self._mock_state_holder = recorder

    # setattr keeps static checkers from objecting to method reassignment.
    setattr(holder, "sync", sync_spy)  # noqa: B010
    setattr(holder, "mark_dirty", mark_spy)  # noqa: B010

    self._state = holder.state

    self._min_nodes, self._max_nodes = 1, 1

    self._timeout = RendezvousTimeout()

    self._now = datetime(2000, 1, 1, hour=0, minute=0)

    # The patcher is kept on self so it can be stopped later.
    self._datetime_patch = patch(
        "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
    )
    self._datetime_patch.start().utcnow.return_value = self._now
def test_init_initializes_timeout_if_no_timeout_is_specified(self) -> None:
    """A ``RendezvousTimeout()`` built without arguments exposes the expected default durations."""
    timeout = RendezvousTimeout()

    # Checked in this order: join, last_call, close, heartbeat.
    expected_defaults = (
        ("join", timedelta(seconds=600)),
        ("last_call", timedelta(seconds=30)),
        ("close", timedelta(seconds=30)),
        ("heartbeat", timedelta(seconds=5)),
    )
    for attribute, expected in expected_defaults:
        self.assertEqual(getattr(timeout, attribute), expected)
def setUp(self) -> None:
    """Provide a dummy store/backend, node bounds of 3..6, and a default timeout."""
    self._run_id = "dummy_run_id"

    self._store = DummyStore()
    self._backend = DummyRendezvousBackend()

    self._min_nodes, self._max_nodes = 3, 6

    # Annotated as Optional, presumably so individual tests can set it to None.
    self._timeout: Optional[RendezvousTimeout] = RendezvousTimeout()
def setUp(self) -> None:
    """Prepare a fake backend with spied state accessors, default settings, and a frozen clock.

    ``get_state`` and ``set_state`` on the fake backend are replaced by
    ``MagicMock`` wrappers that delegate to the real methods. The same
    wrappers are also attached to ``self._mock_backend`` so calls can be
    inspected in one place. ``datetime`` in the dynamic_rendezvous module is
    patched so that ``utcnow()`` returns ``self._now``.
    """
    backend = FakeRendezvousBackend()
    self._backend = backend

    get_spy = MagicMock(wraps=backend.get_state)
    set_spy = MagicMock(wraps=backend.set_state)

    recorder = Mock()
    recorder.get_state = get_spy
    recorder.set_state = set_spy
    self._mock_backend = recorder

    # setattr keeps static checkers from objecting to method reassignment.
    setattr(backend, "get_state", get_spy)  # noqa: B010
    setattr(backend, "set_state", set_spy)  # noqa: B010

    self._settings = RendezvousSettings(
        run_id="dummy_run_id",
        min_nodes=1,
        max_nodes=1,
        timeout=RendezvousTimeout(),
        keep_alive_interval=timedelta(seconds=30),
        keep_alive_max_attempt=3,
    )

    self._cache_duration = 0

    self._now = datetime(2000, 1, 1, hour=0, minute=0)

    # The patcher is kept on self so it can be stopped later.
    self._datetime_patch = patch(
        "torch.distributed.elastic.rendezvous.dynamic_rendezvous.datetime"
    )
    self._datetime_patch.start().utcnow.return_value = self._now
def test_create_handler_returns_handler_if_timeout_is_not_specified(self) -> None:
    """With every timeout option removed, the handler should carry a default ``RendezvousTimeout``.

    The three timeout entries are deleted from the parameter config. The
    expectation is switched to a default-constructed timeout, and the shared
    ``test_create_handler_returns_handler`` check is rerun.
    """
    for option in ("join_timeout", "last_call_timeout", "close_timeout"):
        del self._params.config[option]

    self._expected_timeout = RendezvousTimeout()

    self.test_create_handler_returns_handler()
def test_init_raises_error_if_timeout_is_not_positive(self) -> None:
    """``RendezvousTimeout`` must reject a zero or negative join timeout with ``ValueError``.

    Each offending value runs in its own subtest, so one failure does not
    hide the others. The error message is matched exactly against
    ``The join timeout (<value>) must be positive.``.
    """
    import re

    join_timeouts = [timedelta(seconds=0), timedelta(seconds=-1)]

    for join_timeout in join_timeouts:
        with self.subTest(join_timeout=join_timeout):
            # Escape both the interpolated timedelta text and the final
            # period. Otherwise "." matches any character and special
            # characters in the value would be read as regex syntax.
            pattern = (
                rf"^The join timeout \({re.escape(str(join_timeout))}\) "
                r"must be positive\.$"
            )
            with self.assertRaisesRegex(ValueError, pattern):
                RendezvousTimeout(join_timeout)
def test_init_initializes_timeout(self) -> None:
    """Durations passed positionally (join, last_call, close) are exposed unchanged."""
    join, last_call, close = (timedelta(seconds=s) for s in (50, 60, 70))

    timeout = RendezvousTimeout(join, last_call, close)

    self.assertEqual(timeout.join, join)
    self.assertEqual(timeout.last_call, last_call)
    self.assertEqual(timeout.close, close)
def _get_next_action(self) -> _Action:
    """Invoke the operation under test once and return the action it chooses.

    The operation comes from ``self._create_op()``. It is called with a
    ``_RendezvousContext`` built from this test's node, state and node bounds,
    plus ``self._deadline``.
    """
    operation = self._create_op()

    # NOTE(review): uses a fresh default RendezvousTimeout rather than any
    # per-test timeout attribute; confirm this is intended.
    rendezvous_settings = RendezvousSettings(
        run_id="dummy_run_id",
        min_nodes=self._min_nodes,
        max_nodes=self._max_nodes,
        timeout=RendezvousTimeout(),
        keep_alive_interval=self._keep_alive_interval,
        keep_alive_max_attempt=3,
    )

    return operation(
        _RendezvousContext(self._node, self._state, rendezvous_settings),
        self._deadline,
    )
def _create_handler(self) -> DynamicRendezvousHandler:
    """Build a ``DynamicRendezvousHandler`` wired to this test's fixtures.

    The settings use the test's node bounds, keep-alive interval and
    join/close/heartbeat timeouts. Before the handler is constructed, the
    state holder is seeded with ``self._state``.
    """
    configured_timeout = RendezvousTimeout(
        join=self._join_timeout,
        close=self._close_timeout,
        heartbeat=self._heartbeat_timeout,
    )

    handler_settings = RendezvousSettings(
        run_id="dummy_run_id",
        min_nodes=self._min_nodes,
        max_nodes=self._max_nodes,
        timeout=configured_timeout,
        keep_alive_interval=self._keep_alive_interval,
        keep_alive_max_attempt=3,
    )

    # Seed the holder so the handler starts from the test's prepared state.
    self._state_holder.state = self._state

    return DynamicRendezvousHandler(
        self._node,
        handler_settings,
        "dummy_backend",
        self._store,
        self._state_holder,
    )
def setUp(self) -> None:
    """Provide rendezvous parameters with explicit string timeouts and the matching expected timeout.

    The join/last-call/close options are passed as the strings "50", "60"
    and "70". The expected ``RendezvousTimeout`` holds the same values as
    second-based timedeltas, in the same order.
    """
    self._store = DummyStore()
    self._backend = DummyRendezvousBackend()

    timeout_options = {
        "join_timeout": "50",
        "last_call_timeout": "60",
        "close_timeout": "70",
    }

    self._params = RendezvousParameters(
        backend=self._backend.name,
        endpoint="dummy_endpoint",
        run_id="dummy_run_id",
        min_nodes=3,
        max_nodes=6,
        **timeout_options,
    )

    self._expected_timeout = RendezvousTimeout(
        *(timedelta(seconds=int(value)) for value in timeout_options.values())
    )