def test_create_backend_raises_error_if_store_type_is_invalid(self) -> None:
    """Check that ``create_backend`` rejects an unsupported store type.

    Sets ``store_type`` in the rendezvous config to a dummy value and expects
    ``create_backend`` to raise a ``ValueError`` whose message states that
    only ``'tcp'`` is supported.
    """
    self._params.config["store_type"] = "dummy_store_type"

    # The dots are escaped so they match only the literal sentence
    # punctuation; an unescaped "." would accept any character there.
    expected_message = (
        r"^The store type must be 'tcp'\. "
        r"Other store types are not supported yet\.$"
    )

    with self.assertRaisesRegex(ValueError, expected_message):
        create_backend(self._params)
def test_create_backend_raises_error_if_read_timeout_is_invalid(self) -> None:
    """Check that ``create_backend`` rejects non-positive read timeouts.

    Each invalid value (``"0"`` and ``"-10"``) runs in its own sub-test and
    must make ``create_backend`` raise a ``ValueError``.
    """
    # The trailing dot is escaped so it matches only a literal period.
    expected_message = r"^The read timeout must be a positive integer\.$"

    for read_timeout in ("0", "-10"):
        with self.subTest(read_timeout=read_timeout):
            self._params.config["read_timeout"] = read_timeout

            with self.assertRaisesRegex(ValueError, expected_message):
                create_backend(self._params)
def test_create_backend_raises_error_if_store_is_unreachable(self) -> None:
    """Check that a failed store connection surfaces as a rendezvous error.

    Configures the parameters with ``is_host`` set to ``"false"`` and a
    ``read_timeout`` of ``"2"``, then expects ``create_backend`` to raise
    ``RendezvousConnectionError``.
    """
    # NOTE(review): presumably nothing is listening on the endpoint that the
    # fixture configures, so the connection attempt fails - confirm in setUp.
    self._params.config["is_host"] = "false"
    self._params.config["read_timeout"] = "2"

    # The dots are escaped so they match only the literal sentence
    # punctuation; an unescaped "." would accept any character there.
    expected_message = (
        r"^The connection to the C10d store has failed\. "
        r"See inner exception for details\.$"
    )

    with self.assertRaisesRegex(RendezvousConnectionError, expected_message):
        create_backend(self._params)
def test_create_backend_raises_error_if_store_type_is_invalid(self) -> None:
    """Check that ``create_backend`` rejects an unknown store type.

    Sets ``store_type`` in the rendezvous config to a dummy value and expects
    a ``ValueError`` whose message says only file and tcp are supported.
    """
    self._params.config["store_type"] = "dummy_store_type"

    # The dots are escaped so they match only the literal sentence
    # punctuation; an unescaped "." would accept any character there.
    expected_message = (
        r"^Invalid store type given\. "
        r"Currently only supports file and tcp\.$"
    )

    with self.assertRaisesRegex(ValueError, expected_message):
        create_backend(self._params)
def test_create_backend_raises_error_if_file_path_is_invalid(
    self, filestore_mock
) -> None:
    """Check that a failing file store is reported as a connection error.

    Args:
        filestore_mock: Mock object that is configured here to raise
            ``RuntimeError("test error")`` when called. Presumably injected
            by a patch decorator that is outside this view - confirm.
    """
    filestore_mock.side_effect = RuntimeError("test error")
    self._params_filestore.endpoint = "bad file path"

    # The dots are escaped so they match only the literal sentence
    # punctuation; an unescaped "." would accept any character there.
    expected_message = (
        r"^The connection to the C10d store has failed\. "
        r"See inner exception for details\.$"
    )

    with self.assertRaisesRegex(RendezvousConnectionError, expected_message):
        create_backend(self._params_filestore)
def test_create_backend_raises_error_if_tempfile_creation_fails(
    self, tempfile_mock
) -> None:
    """Check that a temp-file creation failure raises ``RendezvousError``.

    Args:
        tempfile_mock: Mock object that is configured here to raise
            ``OSError("test error")`` when called. Presumably injected by a
            patch decorator that is outside this view - confirm.
    """
    tempfile_mock.side_effect = OSError("test error")

    # Set the endpoint to empty so it defaults to creating a temp file.
    self._params_filestore.endpoint = ""

    # The dots are escaped so they match only literal periods. The pattern
    # has no ^/$ anchors, so it may match anywhere inside the message.
    expected_message = (
        r"The file creation for C10d store has failed\. "
        r"See inner exception for details\."
    )

    with self.assertRaisesRegex(RendezvousError, expected_message):
        create_backend(self._params_filestore)
def test_create_backend_raises_error_if_endpoint_is_invalid(self) -> None:
    """Check that an invalid endpoint raises ``RendezvousConnectionError``.

    The check runs once with ``is_host`` set to ``"True"`` and once with
    ``"False"``, each in its own sub-test, using the endpoint
    ``"dummy_endpoint"``.
    """
    # The dots are escaped so they match only the literal sentence
    # punctuation; an unescaped "." would accept any character there.
    expected_message = (
        r"^The connection to the C10d store has failed\. "
        r"See inner exception for details\.$"
    )

    for is_host in (True, False):
        with self.subTest(is_host=is_host):
            # Config values are strings, hence the explicit conversion.
            self._params.config["is_host"] = str(is_host)
            self._params.endpoint = "dummy_endpoint"

            with self.assertRaisesRegex(
                RendezvousConnectionError, expected_message
            ):
                create_backend(self._params)
def _assert_create_backend_returns_backend(self) -> None:
    """Create a backend from ``self._params`` and verify it and its store.

    Checks, in order: the backend name, the store's type and timeout, the
    store-type-specific attributes (host/port for ``TCPStore``, path for
    ``FileStore``), and finally that state written through the backend can
    be read back from the store in base64-encoded form.
    """
    backend, store = create_backend(self._params)

    self.assertEqual(backend.name, "c10d")
    self.assertIsInstance(store, self._expected_store_type)

    store_cls = self._expected_store_type
    typed_store = cast(store_cls, store)

    self.assertEqual(typed_store.timeout, self._expected_read_timeout)  # type: ignore[attr-defined]

    if store_cls == TCPStore:
        self.assertEqual(typed_store.host, self._expected_endpoint_host)  # type: ignore[attr-defined]
        self.assertEqual(typed_store.port, self._expected_endpoint_port)  # type: ignore[attr-defined]

    if store_cls == FileStore:
        if not self._params.endpoint:
            # No explicit endpoint: the store path is expected to live under
            # the expected temporary directory.
            path_in_temp_dir = typed_store.path.startswith(  # type: ignore[attr-defined]
                self._expected_temp_dir
            )
            self.assertTrue(path_in_temp_dir)
        else:
            self.assertEqual(typed_store.path, self._expected_endpoint_file)  # type: ignore[attr-defined]

    # Round-trip: state set on the backend is read straight from the store.
    backend.set_state(b"dummy_state")

    state_key = "torch.rendezvous." + self._params.run_id
    stored_state = store.get(state_key)

    self.assertEqual(stored_state, b64encode(b"dummy_state"))
def test_create_backend_returns_backend(self) -> None:
    """Check the backend returned by ``create_backend`` and its store.

    Verifies the store type, the backend name and key, and then the store's
    host, port and timeout against the expected fixture values.
    """
    created_backend = create_backend(self._params)

    self.assertIsInstance(created_backend.store, self._expected_store_type)
    self.assertEqual(created_backend.name, "c10d")

    expected_key = "torch.rendezvous." + self._params.run_id
    self.assertEqual(created_backend.key, expected_key)

    backing_store = created_backend.store

    self.assertEqual(backing_store.host, self._expected_endpoint_host)  # type: ignore[attr-defined]
    self.assertEqual(backing_store.port, self._expected_endpoint_port)  # type: ignore[attr-defined]
    self.assertEqual(backing_store.timeout, self._expected_read_timeout)  # type: ignore[attr-defined]
def test_create_backend_returns_backend(self) -> None:
    """Check the ``(backend, store)`` pair returned by ``create_backend``.

    Verifies the backend name, the store's type, host, port and timeout, and
    that state written through the backend can be read back from the store
    in base64-encoded form.
    """
    backend, store = create_backend(self._params)

    self.assertEqual(backend.name, "c10d")
    self.assertIsInstance(store, self._expected_store_type)

    # The cast only narrows the static type for the attribute checks below.
    typed_store = cast(TCPStore, store)

    self.assertEqual(typed_store.host, self._expected_endpoint_host)  # type: ignore[attr-defined]
    self.assertEqual(typed_store.port, self._expected_endpoint_port)  # type: ignore[attr-defined]
    self.assertEqual(typed_store.timeout, self._expected_read_timeout)  # type: ignore[attr-defined]

    # Round-trip: state set on the backend is read straight from the store.
    backend.set_state(b"dummy_state")

    state_key = "torch.rendezvous." + self._params.run_id
    stored_state = store.get(state_key)

    self.assertEqual(stored_state, b64encode(b"dummy_state"))