예제 #1
0
 def test_ipv6_addr(self):
     """A bracketed full IPv6 endpoint must make handler creation fail.

     Building a ``static`` rendezvous handler from an endpoint of the form
     ``[<ipv6>]:<port>`` is expected to raise ``ValueError``.
     """
     endpoint = "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:90"
     params = RendezvousParameters(
         backend="static",
         endpoint=endpoint,
         run_id="test_id",
         min_nodes=1,
         max_nodes=1,
     )
     # Callable form of assertRaises: fails unless the call raises ValueError.
     self.assertRaises(ValueError, create_rdzv_handler, params)
예제 #2
0
 def test_ipv6_addr_localhost(self):
     """The bracketed IPv6 loopback endpoint must make handler creation fail.

     Building a ``static`` rendezvous handler from ``[::1]:<port>`` is
     expected to raise ``ValueError``.
     """
     endpoint = "[::1]:90"
     params = RendezvousParameters(
         backend="static",
         endpoint=endpoint,
         run_id="test_id",
         min_nodes=1,
         max_nodes=1,
     )
     # Callable form of assertRaises: fails unless the call raises ValueError.
     self.assertRaises(ValueError, create_rdzv_handler, params)
예제 #3
0
 def test_empty_endpoint(self):
     """An empty endpoint string must make handler creation fail.

     Building a ``static`` rendezvous handler with ``endpoint=""`` is
     expected to raise ``ValueError``.
     """
     endpoint = ""
     params = RendezvousParameters(
         backend="static",
         endpoint=endpoint,
         run_id="test_id",
         min_nodes=1,
         max_nodes=1,
     )
     # Callable form of assertRaises: fails unless the call raises ValueError.
     self.assertRaises(ValueError, create_rdzv_handler, params)
예제 #4
0
    def test_static_rdzv_multiple_calls(self):
        """Calling ``next_rendezvous`` twice on one static handler stays consistent.

        A port number is read from the socket returned by
        ``get_socket_with_port``. That socket is closed before the handler is
        built, so the port is no longer held by this test. A single-node
        (rank 0) static handler is then asked to rendezvous twice. Every call
        must return a non-None store, rank 0 and world size 1.
        """
        with closing(get_socket_with_port()) as probe:
            port = probe.getsockname()[1]
        host = "localhost"

        handler = create_rdzv_handler(
            RendezvousParameters(
                backend="static",
                endpoint=f"{host}:{port}",
                run_id="test_id",
                min_nodes=1,
                max_nodes=1,
                rank=0,
            )
        )

        # The same handler is reused; each round must produce identical results.
        for _round in range(2):
            store, rank, world_size = handler.next_rendezvous()
            self.assertIsNotNone(store)
            self.assertEqual(0, rank)
            self.assertEqual(1, world_size)
예제 #5
0
    def test_get_backend(self):
        """A handler created with ``backend="static"`` reports ``"static"``."""
        handler = create_rdzv_handler(
            RendezvousParameters(
                backend="static",
                endpoint="localhost:123",
                run_id="test",
                min_nodes=1,
                max_nodes=1,
                timeout=60,
                rank=0,
            )
        )
        self.assertEqual("static", handler.get_backend())