def test_parse_rendezvous_endpoint_raises_error_if_port_is_too_big(self) -> None:
    # Endpoints whose port is 65536 or 70000 must make the parser raise a
    # ValueError carrying the out-of-range port message.
    for bad_endpoint in ("dummy.com:65536", "dummy.com:70000"):
        with self.subTest(endpoint=bad_endpoint):
            # The expected message embeds the offending endpoint verbatim.
            message_pattern = (
                rf"^The port number of the rendezvous endpoint '{bad_endpoint}' must be an integer "
                r"between 0 and 65536.$"
            )
            with self.assertRaisesRegex(ValueError, message_pattern):
                parse_rendezvous_endpoint(bad_endpoint, default_port=123)
def test_parse_rendezvous_endpoint_raises_error_if_hostname_is_invalid(
    self,
) -> None:
    # Every endpoint below is expected to be rejected with a ValueError whose
    # message describes the accepted hostname forms.
    invalid_endpoints = ("~", "dummy.com :123", "~:123", ":123")
    for invalid in invalid_endpoints:
        with self.subTest(endpoint=invalid):
            # The expected message quotes the full endpoint string as given.
            expected_error = (
                rf"^The hostname of the rendezvous endpoint '{invalid}' must be a "
                r"dot-separated list of labels, an IPv4 address, or an IPv6 address.$"
            )
            with self.assertRaisesRegex(ValueError, expected_error):
                parse_rendezvous_endpoint(invalid, default_port=123)
def test_parse_rendezvous_endpoint_returns_tuple(self) -> None:
    # Well-formed "host:port" endpoints must be split into (host, port), with
    # the port converted to int and IPv6 square brackets removed from the host.
    hosts = ("dummy.com", "dummy-1.com", "123.123.123.123", "[2001:db8::1]")
    ports = ("0", "123", "65535")
    # Host-major ordering: every port is tried for a host before the next host.
    candidates = [f"{host_text}:{port_text}" for host_text in hosts for port_text in ports]

    for candidate in candidates:
        with self.subTest(endpoint=candidate):
            host, port = parse_rendezvous_endpoint(candidate, default_port=123)

            # Split on the LAST colon so the colons inside an IPv6 literal
            # stay part of the host.
            wanted_host, _, wanted_port = candidate.rpartition(":")
            if wanted_host.startswith("[") and wanted_host.endswith("]"):
                wanted_host = wanted_host[1:-1]

            self.assertEqual(host, wanted_host)
            self.assertEqual(port, int(wanted_port))
def test_parse_rendezvous_endpoint_returns_tuple_if_endpoint_is_empty(self) -> None:
    # Both an empty and a whitespace-only endpoint are expected to resolve to
    # ("localhost", default_port).
    endpoints = ["", " "]

    for endpoint in endpoints:
        with self.subTest(endpoint=endpoint):
            # Pass the loop variable (not a hard-coded "") so that the
            # whitespace-only case is actually exercised.
            host, port = parse_rendezvous_endpoint(endpoint, default_port=123)

            self.assertEqual(host, "localhost")
            self.assertEqual(port, 123)
def test_parse_rendezvous_endpoint_returns_tuple_if_endpoint_has_no_port(
    self,
) -> None:
    # Endpoints given without a port must come back with the supplied
    # default_port and with the host unchanged, apart from IPv6 brackets
    # being removed.
    portless_endpoints = ("dummy.com", "dummy-1.com", "123.123.123.123", "[2001:db8::1]")

    for candidate in portless_endpoints:
        with self.subTest(endpoint=candidate):
            host, port = parse_rendezvous_endpoint(candidate, default_port=123)

            bracketed = candidate.startswith("[") and candidate.endswith("]")
            wanted_host = candidate[1:-1] if bracketed else candidate

            self.assertEqual(host, wanted_host)
            self.assertEqual(port, 123)
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Extract the master address and port from static-backend parameters.

    Args:
        rdzv_parameters: Rendezvous parameters; only ``backend`` and
            ``endpoint`` are read.

    Returns:
        ``(None, None)`` when the backend is not ``"static"``; otherwise the
        ``(address, port)`` pair parsed from the whitespace-stripped endpoint.

    Raises:
        ValueError: If the stripped endpoint is empty, or if the parsed port
            equals the ``-1`` sentinel passed as the default port.
    """
    if rdzv_parameters.backend != "static":
        return (None, None)

    stripped_endpoint = rdzv_parameters.endpoint.strip()
    if not stripped_endpoint:
        raise ValueError(
            "Endpoint is missing in endpoint. Try to add --master_addr and --master_port"
        )

    # -1 is handed to the parser as the default port so that a result of -1
    # can be treated as "no port was given".
    port_not_given = -1
    address, port = parse_rendezvous_endpoint(
        stripped_endpoint, default_port=port_not_given
    )
    if port == port_not_given:
        raise ValueError(
            f"port is missing in endpoint: {stripped_endpoint}. Try to specify --master_port"
        )
    return (address, port)
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a ``StaticTCPRendezvous`` handler from rendezvous parameters.

    Args:
        params: Rendezvous parameters. ``config`` must contain ``"rank"`` and
            may contain ``"timeout"``; ``endpoint``, ``max_nodes`` and
            ``run_id`` are also read.

    Returns:
        A ``StaticTCPRendezvous`` constructed from the parsed master address
        and port, the rank, ``max_nodes`` (used as the world size), the run
        id, and the timeout.

    Raises:
        ValueError: If ``"rank"`` is not in ``params.config``, if the
            whitespace-stripped endpoint is empty, or if the parsed port
            equals the ``-1`` sentinel passed as the default port.
    """
    if "rank" not in params.config:
        # The two literals are concatenated at compile time, so the separating
        # space has to be written explicitly.
        raise ValueError(
            "rank is absent in RendezvousParameters. "
            "Try add --node_rank to the cmd request"
        )

    endpoint = params.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "endpoint is absent in RendezvousParameters. "
            "Try add --master_port and --master_addr to the cmd request"
        )

    # -1 is handed to the parser as the default port so that a result of -1
    # can be treated as "no port was given".
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    if master_port == -1:
        raise ValueError(
            f"Port is absent in endpoint: {endpoint}. Try launching with --master_port"
        )

    world_size = params.max_nodes
    # cast() only informs the type checker; the stored value is used as-is.
    rank = cast(int, params.config.get("rank"))
    run_id = params.run_id
    if "timeout" in params.config:
        timeout = int(params.config["timeout"])
    else:
        timeout = _default_timeout_seconds

    return StaticTCPRendezvous(master_addr, master_port, rank, world_size, run_id, timeout)