def test_static_rdzv_multiple_calls(self):
    """Ask one static rendezvous handler to rendezvous twice in a row.

    The handler is built for a single node (``min_nodes == max_nodes == 1``,
    ``rank=0``) on a localhost endpoint. Every ``next_rendezvous()`` round
    must return a non-``None`` store, rank 0 and a world size of 1.
    """
    # Read the port the probe socket is bound to and close the socket
    # immediately (presumably so the port is free for the endpoint below).
    probe = get_socket_with_port()
    with closing(probe):
        port = probe.getsockname()[1]
    host = "localhost"

    handler = create_rdzv_handler(
        RendezvousParameters(
            backend="static",
            endpoint=f"{host}:{port}",
            run_id="test_id",
            min_nodes=1,
            max_nodes=1,
            rank=0,
        )
    )

    # Two consecutive rounds on the same handler must give the same answer.
    for _round in range(2):
        store, rank, world_size = handler.next_rendezvous()
        self.assertIsNotNone(store)
        self.assertEqual(0, rank)
        self.assertEqual(1, world_size)
def test_launch_dist_sum_with_static_rdzv(self):
    """Run ``_dist_sum`` through ``elastic_launch`` using the static rendezvous backend.

    One node with four local processes is launched. After sorting, the
    collected results must be ``nproc_per_node`` copies of
    ``sum(range(nproc_per_node))``.
    """
    nnodes = 1
    nproc_per_node = 4

    # Take the port from a probe socket, closing it before the port is used.
    probe = get_socket_with_port()
    with closing(probe):
        port = probe.getsockname()[1]
    endpoint = f"127.0.0.1:{port}"

    # The node rank (0) reaches the static rendezvous through its config dict;
    # nnodes is passed twice (as the min and max node count positions).
    launch_config = get_test_launch_config(
        endpoint,
        nnodes,
        nnodes,
        nproc_per_node,
        rdzv_backend="static",
        config={"rank": 0},
    )
    results = elastic_launch(launch_config, _dist_sum)()

    expected = [sum(range(nproc_per_node))] * nproc_per_node
    self.assertEqual(expected, sorted(results.values()))
def test_launch_user_script_python_caffe2_bc(self):
    """Launch ``bin/test_script.py`` via ``launch.main`` with explicit master flags.

    The command line uses ``--master_addr``, ``--master_port`` and
    ``--node_rank``, for one node with four forked workers. Afterwards
    ``self.test_dir`` must contain exactly one entry per global rank.
    """
    nnodes = 1
    nproc_per_node = 4
    world_size = nnodes * nproc_per_node

    # Take the port from a probe socket, closing it before the port is used.
    probe = get_socket_with_port()
    with closing(probe):
        port = probe.getsockname()[1]

    launch.main(
        [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--master_addr=localhost",
            f"--master_port={port}",
            "--node_rank=0",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
    )

    # Each worker touches a file named after its global rank, so the listing
    # must be exactly {"0", ..., str(world_size - 1)}.
    expected_files = set(map(str, range(world_size)))
    self.assertSetEqual(expected_files, set(os.listdir(self.test_dir)))
def test_launch_without_env(self):
    """Launch ``bin/test_script_local_rank.py`` on one node with four forked workers.

    This method makes no assertions of its own. It passes as long as
    ``launch.main`` returns without raising.
    """
    nnodes = 1
    nproc_per_node = 4
    # (The former ``world_size = nnodes * nproc_per_node`` local was never
    # used and has been removed.)

    # Take the port from a probe socket, closing it before the port is used.
    sock = get_socket_with_port()
    with closing(sock):
        master_port = sock.getsockname()[1]

    args = [
        f"--nnodes={nnodes}",
        f"--nproc_per_node={nproc_per_node}",
        "--monitor_interval=1",
        "--start_method=fork",
        "--master_addr=localhost",
        f"--master_port={master_port}",
        "--node_rank=0",
        path("bin/test_script_local_rank.py"),
    ]
    # NOTE(review): any checks on the workers' environment presumably live
    # inside test_script_local_rank.py; confirm, since nothing is asserted here.
    launch.main(args)