def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
    os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000')
    rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size)
    rpc._set_rpc_timeout(self.rpc_timeout_sec)
    self._is_rpc_initialized = True
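# A minimal, hypothetical driver for init_rpc_connection above. It assumes the
# function is a method on an object exposing rpc_timeout_sec, and that
# MASTER_ADDR is exported elsewhere (the method itself only sets MASTER_PORT).
# A single-process world is shown purely for illustration.
import os
import types

import torch.distributed.rpc as rpc

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    owner = types.SimpleNamespace(rpc_timeout_sec=60.0)  # assumed attribute
    init_rpc_connection(owner, global_rank=0, world_size=1)  # self passed explicitly
    rpc.shutdown()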
def test_timeout_in_torchscript_function(self):
    # Call rpc_async + fut.wait() in a torchscript function and ensure that
    # a timeout is raised.
    if self.rank != 0:
        return

    dst_worker_name = worker_name((self.rank + 1) % self.world_size)
    args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
    kwargs = {
        "first_kwarg": torch.tensor([2, 2]),
        "second_kwarg": torch.tensor([3, 3]),
    }
    expected_error = self.get_timeout_error_regex()

    # Ensure that we get a timeout if we override the default timeout and
    # the RPC takes longer to execute.
    with self.assertRaisesRegex(RuntimeError, expected_error):
        rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)

    # Ensure that we time out if we don't specify a timeout but the default
    # is less than the time the RPC takes to execute.
    rpc._set_rpc_timeout(0.001)
    with self.assertRaisesRegex(RuntimeError, expected_error):
        script_rpc_async_call(dst_worker_name, args, kwargs)

    # Ensure that we run to completion if a zero timeout is specified.
    ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
    self.assertEqual(ret, torch.tensor([8, 8]))

    # Reset for clean shutdown.
    rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
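# The TorchScript helpers used above are not shown in this snippet. A plausible
# sketch follows, assuming a remote function two_args_two_kwargs whose defaults
# make [1,1] + [2,2] + [2,2] + [3,3] == [8,8]; names and signatures here are
# illustrative stand-ins for the test fixtures, not their exact definitions.
from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor


@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    return first_arg + second_arg + first_kwarg + second_kwarg


@torch.jit.script
def rpc_async_call_with_timeout(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
) -> Tensor:
    # rpc_async inside TorchScript accepts an explicit per-call timeout.
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
    return fut.wait()


@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
) -> Tensor:
    # No timeout argument: the currently set default RPC timeout applies.
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return fut.wait()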
def test_rpc_timeouts(self):
    dst_rank = (self.rank + 1) % self.world_size
    rpc._set_rpc_timeout(timedelta(milliseconds=1))
    # Futures should time out and be marked with a timeout exception.
    futs = [
        rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=())
        for _ in range(10)
    ]
    for fut in futs:
        with self.assertRaisesRegex(RuntimeError, "RPC ran for more than"):
            fut.wait()

    # Ensure that if a new timeout is set, old futures don't time out but new ones do.
    rpc._set_rpc_timeout(timedelta(seconds=200))
    # Create a longstanding RPC.
    fut1 = rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=(1,))
    # Now, set a short timeout.
    rpc._set_rpc_timeout(timedelta(milliseconds=1))
    # fut2 should time out, fut1 should not.
    fut2 = rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=(1,))
    with self.assertRaises(RuntimeError):
        fut2.wait()
    fut1.wait()

    # The future should run to completion if the timeout is zero.
    rpc._set_rpc_timeout(timedelta(seconds=0))
    rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=()).wait()

    # Reset to the default timeout so shutdown messages can process cleanly.
    rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT)
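# my_sleep_func is assumed to be a Python UDF that blocks long enough to trip
# the 1 ms timeouts above; a plausible sketch (the real fixture may differ):
import time


def my_sleep_func(seconds=1):
    # Sleeping for the requested duration makes short RPC timeouts fire.
    time.sleep(seconds)
    return seconds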
def test_timeout_in_python(self):
    # Ensures timeouts are raised if we call rpc_async from within a
    # torchscript function, but wait on the future in python.
    if self.rank != 0:
        return

    dst_worker_name = worker_name((self.rank + 1) % self.world_size)
    args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
    kwargs = {
        "first_kwarg": torch.tensor([2, 2]),
        "second_kwarg": torch.tensor([3, 3]),
    }
    expected_error = self.get_timeout_error_regex()

    fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
    with self.assertRaisesRegex(RuntimeError, expected_error):
        fut.wait()

    # Ensure that we time out if we don't specify a timeout but the default
    # is less than the time the RPC takes to execute.
    rpc._set_rpc_timeout(0.001)
    fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
    with self.assertRaisesRegex(RuntimeError, expected_error):
        fut.wait()

    # Ensure that we run to completion if a zero timeout is specified.
    fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
    result = fut.wait()
    self.assertEqual(result, torch.tensor([8, 8]))

    # Reset for clean shutdown.
    rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
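# The *_future_ret helpers above differ from the earlier sketches only in that
# the TorchScript function returns the Future itself instead of waiting on it,
# so the timeout surfaces when fut.wait() is called from Python. A sketch under
# the same assumptions as before (two_args_two_kwargs as sketched earlier):
from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.futures import Future


@torch.jit.script
def rpc_async_call_with_timeout_future_ret(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    timeout: float,
) -> Future[Tensor]:
    # Return the future without waiting; the caller waits from Python.
    return rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)


@torch.jit.script
def rpc_async_call_future_ret(
    dst_worker_name: str,
    args: Tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
) -> Future[Tensor]:
    # No timeout argument: the currently set default RPC timeout applies.
    return rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)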
def run_parameter_server(rank, world_size):
    # The parameter server just acts as a host for the model and responds to
    # requests from trainers.
    # rpc.shutdown() will wait for all workers to complete by default, which
    # in this case means that the parameter server will wait for all trainers
    # to complete, and then exit.
    print('PS master initializing RPC')
    rpc.init_rpc(name='parameter_server', rank=rank, world_size=world_size)
    rpc._set_rpc_timeout(timedelta(seconds=60))
    print('RPC initialized! Running parameter server...')
    rpc.shutdown(graceful=True)
    print('RPC shutdown on parameter server.')
def run_worker(rank, world_size, epochs, batch_size, learning_rate, input_dim,
               hidden_dim, layer_dim, output_dim, train_set, validation_set,
               test_set):
    print(f'Worker rank {rank} initializing RPC')
    rpc.init_rpc(
        name=f'trainer_{rank}',
        rank=rank,
        world_size=world_size)
    rpc._set_rpc_timeout(timedelta(seconds=60))
    print(f'Worker {rank} done initializing RPC')

    model = WorkerNetwork(input_dim, hidden_dim, layer_dim, output_dim)
    worker = ParameterWorkerTrainer(rank, world_size, model, train_set,
                                    batch_size, learning_rate, validation_set,
                                    test_set)
    worker.train(epochs)
    rpc.shutdown()
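# A hypothetical launcher tying run_parameter_server and run_worker together:
# rank 0 hosts the parameter server and ranks 1..world_size-1 run trainers.
# The hyperparameters and dataset arguments below are placeholders, not values
# taken from the original code.
import os

import torch.multiprocessing as mp


def _run(rank, world_size, train_set, validation_set, test_set):
    if rank == 0:
        run_parameter_server(rank, world_size)
    else:
        run_worker(
            rank, world_size,
            epochs=10, batch_size=32, learning_rate=1e-3,
            input_dim=28, hidden_dim=128, layer_dim=1, output_dim=10,
            train_set=train_set, validation_set=validation_set,
            test_set=test_set,
        )


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    world_size = 3  # one parameter server + two trainers
    mp.spawn(_run, args=(world_size, None, None, None), nprocs=world_size)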
def test_rpc_builtin_timeout(self):
    next_rank = (self.rank + 1) % self.world_size
    dst_worker = worker_name(next_rank)
    expected_error = self.get_timeout_error_regex()
    # PYTHON_CALL message types, which correspond to Python UDFs over RPC,
    # by default get a delay (see faulty_rpc_agent_test_fixture).
    with self.assertRaisesRegex(RuntimeError, expected_error):
        rpc.rpc_sync(
            dst_worker,
            torch.add,
            args=(torch.tensor(1), torch.tensor(1)),
            timeout=1,
        )

    fut = rpc.rpc_async(
        dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1
    )
    with self.assertRaisesRegex(RuntimeError, expected_error):
        fut.wait()

    # Ensure that the currently set default timeout is large enough such
    # that RPCs with delays still complete.
    fut = rpc.rpc_async(
        dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
    )
    fut.wait()

    # Ensure that we time out if we set a new default and don't override it.
    rpc._set_rpc_timeout(0.001)
    fut = rpc.rpc_async(
        dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
    )
    with self.assertRaisesRegex(RuntimeError, expected_error):
        fut.wait()

    # Ensure that we run to completion if we specify a timeout of 0.
    fut = rpc.rpc_async(
        dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0
    )
    fut.wait()

    # Reset for clean shutdown.
    rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
def test_rpc_script_timeout(self):
    next_rank = (self.rank + 1) % self.world_size
    dst_worker = worker_name(next_rank)
    expected_error = self.get_timeout_error_regex()
    with self.assertRaisesRegex(RuntimeError, expected_error):
        rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)

    fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
    with self.assertRaisesRegex(RuntimeError, expected_error):
        fut.wait()

    # Ensure that the currently set default timeout is large enough such
    # that RPCs with delays still complete.
    fut = rpc.rpc_async(
        dst_worker, my_script_func, args=(torch.tensor(1),)
    )
    fut.wait()

    # Ensure that we time out if we set a new default and don't override it.
    rpc._set_rpc_timeout(0.001)
    fut = rpc.rpc_async(
        dst_worker, my_script_func, args=(torch.tensor(1),)
    )
    with self.assertRaisesRegex(RuntimeError, expected_error):
        fut.wait()

    # Ensure that we run to completion if we specify a timeout of 0; the short
    # default set here is overridden by the explicit per-call timeout.
    rpc._set_rpc_timeout(0.001)
    fut = rpc.rpc_async(
        dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0
    )
    fut.wait()

    # Reset for clean shutdown.
    rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
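# my_script_func is not shown here; in this style of test it is typically a
# trivial TorchScript function, with the delay injected by the faulty-agent
# fixture rather than by the function body. A sketch (an assumption, not the
# exact fixture):
import torch


@torch.jit.script
def my_script_func(tensor):
    # The remote work itself is cheap; timeouts come from injected delays.
    return torch.add(tensor, tensor)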