def _get_worker_spec(
    self,
    max_restarts=1,
    monitor_interval=1.0,
    role="test_trainer",
    local_world_size=8,
):
    """Build a ``WorkerSpec`` wired to a single-node static rendezvous.

    The rendezvous uses the ``"static"`` backend on ``127.0.0.1`` with a
    port taken from ``get_free_port()``, a freshly generated run id,
    exactly one node (``min_nodes == max_nodes == 1``) and rank 0.

    Args:
        max_restarts: Forwarded to ``WorkerSpec(max_restarts=...)``.
        monitor_interval: Forwarded to ``WorkerSpec(monitor_interval=...)``.
        role: Forwarded to ``WorkerSpec(role=...)``.
        local_world_size: Forwarded to ``WorkerSpec(local_world_size=...)``.

    Returns:
        The constructed ``WorkerSpec``; its entry point is ``do_nothing``
        called with no arguments.
    """
    # Generate the run id before grabbing the port (same order as always).
    unique_run_id = str(uuid.uuid4().int)
    free_port = get_free_port()

    handler = rdzv_registry.get_rendezvous_handler(
        RendezvousParameters(
            backend="static",
            endpoint=f"127.0.0.1:{free_port}",
            run_id=unique_run_id,
            min_nodes=1,
            max_nodes=1,
            rank=0,
        )
    )

    return WorkerSpec(
        role=role,
        local_world_size=local_world_size,
        fn=do_nothing,
        args=(),
        rdzv_handler=handler,
        max_restarts=max_restarts,
        monitor_interval=monitor_interval,
    )
def test_create_store_timeout_on_worker(self):
    """Worker-side ``create_c10d_store`` must raise ``TimeoutError``.

    The call is made with ``is_server=False`` and a one-second timeout;
    nothing in this test starts a server on the chosen port, so the
    worker is expected to give up with ``TimeoutError``.
    """
    # Acquire the port *outside* the assertRaises block so that an
    # unexpected TimeoutError from test setup cannot satisfy the
    # assertion; only the call under test is guarded.
    port = get_free_port()

    with self.assertRaises(TimeoutError):
        create_c10d_store(
            is_server=False,
            server_addr=socket.gethostname(),
            server_port=port,
            world_size=2,
            timeout=1,
        )
def test_init_method_env_with_torchelastic(self):
    """Run the init-method script through ``launch.main`` with ``env://``.

    Launches one node with four processes per node against ``localhost``
    on a port from ``get_free_port()``; the test passes if ``launch.main``
    returns without raising.
    """
    master_port = get_free_port()

    # Launcher options first, then the script path and the script's own flag.
    launcher_args = [
        "--run_path",
        "--nnodes=1",
        "--nproc_per_node=4",
        "--master_addr=localhost",
        f"--master_port={master_port}",
        "--monitor_interval=1",
        path("bin/test_script_init_method.py"),
        "--init_method=env://",
    ]

    launch.main(launcher_args)
def test_init_method_tcp(self):
    """Run the init-method script in-process with a ``tcp://`` init method.

    ``sys.argv`` is temporarily replaced so the script sees rank 0 of a
    world of size 1 and a ``tcp://localhost:<port>`` init method, then the
    script is executed as ``__main__`` via ``runpy``.
    """
    tcp_port = get_free_port()
    script = path("bin/test_script_init_method.py")
    fake_argv = [
        script,
        f"--init_method=tcp://localhost:{tcp_port}",
        "--rank=0",
        "--world_size=1",
    ]

    with patch.object(sys, "argv", fake_argv):
        # sys.argv[0] is the script path while the patch is active.
        runpy.run_path(script, run_name="__main__")
def test_init_method_env(self):
    """Run the init-method script in-process with an ``env://`` init method.

    ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` are
    patched into ``os.environ`` and ``sys.argv`` is replaced for the
    duration of the run; the script is executed as ``__main__`` via
    ``runpy``.
    """
    master_port = get_free_port()
    env_overrides = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(master_port),
    }

    with patch.dict(os.environ, env_overrides):
        # Build argv only after the environment patch is in place.
        fake_argv = [
            path("bin/test_script_init_method.py"),
            "--init_method=env://",
        ]
        with patch.object(sys, "argv", fake_argv):
            runpy.run_path(fake_argv[0], run_name="__main__")
def test_create_store_multi(self):
    """Create a c10d store shared by a server and two worker processes.

    Two child processes run ``_create_c10d_store_mp`` as non-server
    participants while the main process creates the store as the server.
    After the children exit, the store must hold ``"test_value"`` under
    ``test_key/<pid>`` for each child, and both children must have exited
    with code 0.
    """
    world_size = 3
    port = get_free_port()
    host = socket.gethostname()

    # Two identical non-server participants.
    workers = [
        mp.Process(
            target=_create_c10d_store_mp,
            args=(False, host, port, world_size),
        )
        for _ in range(2)
    ]
    for proc in workers:
        proc.start()

    # The main process hosts the store.
    store = create_c10d_store(
        is_server=True,
        server_addr=host,
        server_port=port,
        world_size=world_size,
        timeout=2,
    )

    for proc in workers:
        proc.join()

    # Each worker's key, test_key/<pid>, must map to "test_value".
    for proc in workers:
        stored = store.get(f"test_key/{proc.pid}").decode("UTF-8")
        self.assertEqual("test_value", stored)

    for proc in workers:
        self.assertEqual(0, proc.exitcode)