def test_some_kwargs_are_populated_by_defaults(self):
    if self.rank != 0:
        return
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)
    args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
    kwargs = {"first_kwarg": torch.tensor([2, 2])}
    for script_op in [
        script_rpc_async_call,
        script_rpc_sync_call,
        script_rpc_remote_call,
    ]:
        ret = script_op(dst_worker_name, args, kwargs)
        self.assertEqual(ret, torch.tensor([9, 9]))
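# For reference, a minimal sketch of the TorchScript target these kwargs tests
# call over RPC. The signature and default values are assumptions inferred from
# the expected results ([9, 9] above, and [10, 10] when both kwargs fall back to
# their defaults), not necessarily the suite's actual definition.
@torch.jit.script
def two_args_two_kwargs_sketch(
    first_arg: Tensor,
    second_arg: Tensor,
    first_kwarg: Tensor = torch.tensor([3, 3]),
    second_kwarg: Tensor = torch.tensor([4, 4]),
) -> Tensor:
    # [1, 1] + [2, 2] + [2, 2] + [4, 4] == [9, 9], matching the assertion above.
    return first_arg + second_arg + first_kwarg + second_kwarg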
def test_args_kwargs_are_neither_passed(self):
    if self.rank != 0:
        return
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)

    @torch.jit.script
    def script_rpc_async_call_without_args_kwargs_passed(
        dst_worker_name: str,
    ):
        fut = rpc.rpc_async(dst_worker_name, no_arg)
        ret = fut.wait()
        return ret

    ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name)
    self.assertEqual(ret, 0)
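# A sketch (an assumption consistent with the expected return value of 0) of the
# zero-argument TorchScript function invoked above without args or kwargs.
@torch.jit.script
def no_arg_sketch() -> int:
    return 0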
def __init__(self, world_size):
    self.ob_rrefs = []
    self.agent_rref = RRef(self)
    self.rewards = {}
    self.saved_log_probs = {}
    self.policy = Policy()
    self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
    self.eps = np.finfo(np.float32).eps.item()
    self.running_reward = 0
    self.reward_threshold = DummyEnv().reward_threshold
    for ob_rank in range(1, world_size):
        ob_info = rpc.get_worker_info(worker_name(ob_rank))
        self.ob_rrefs.append(remote(ob_info, Observer))
        self.rewards[ob_info.id] = []
        self.saved_log_probs[ob_info.id] = []
def test_send_remote_module_over_the_wire(self):
    if self.rank != 0:
        return
    dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
    dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)

    # Unpickled attributes include both the inherent attributes of RemoteModule
    # (not inherited from the superclass) and the two installed methods.
    expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
    expected_unpickled_attrs.append("forward_async")
    expected_unpickled_attrs.append("forward")

    # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
    for remote_module in self._create_remote_module_iter(
        dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
    ):
        # Test querying some simple attributes from worker2.
        attrs = rpc.rpc_sync(
            dst_worker2_name, remote_module_attributes, (remote_module,)
        )
        self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs)
        self.assertEqual(attrs["on"], "worker1")
        self.assertEqual(attrs["device"], "cpu")
        self.assertFalse(attrs["is_device_map_set"])
        self.assertFalse(attrs["is_scriptable"])

        # Test that the methods installed on worker1 can be invoked by worker2 over the RPC layer.
        # NOTE: In practice a remote module should be stored directly on the worker that runs
        # ``forward`` or ``forward_async``, rather than having another worker initiate forward
        # over the RPC layer.
        args = (torch.ones(1), 2, "3")
        ret1 = rpc.rpc_sync(dst_worker2_name, remote_forward, (remote_module, args))
        self.assertEqual(ret1, tuple(reversed(args)))
        ret2 = rpc.rpc_sync(
            dst_worker2_name, remote_forward_async, (remote_module, args)
        )
        self.assertEqual(ret2, tuple(reversed(args)))
def test_remote_script_module(self):
    # TODO: needs more investigation. There is an RRef leak during shutdown;
    # the suspicion is that the RRef passed as an arg crosses the pybind
    # boundary and is not garbage collected by Python when shutdown() is called.
    import torch.distributed.rpc.api as api

    api._ignore_rref_leak = True

    local_ret = torch.ones(self.rank) + torch.ones(self.rank)

    n = self.rank + 1
    dst_rank = n % self.world_size
    remote_ref = rpc.remote(
        worker_name(dst_rank), construct_my_script_module, args=(self.rank,)
    )

    # Pass the RRef arg to its owner.
    ret = rpc.rpc_sync(
        worker_name(dst_rank),
        run_ref_script_module,
        args=(remote_ref, torch.ones(self.rank)),
    )
    self.assertEqual(ret, local_ret)
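# Plausible sketches of the two helpers used above; these are assumptions, not
# necessarily the suite's actual definitions. One constructs the script module
# on the owner; the other dereferences the RRef locally on the owner, runs
# forward, and adds the tensor argument.
def construct_my_script_module_sketch(rank):
    return MyScriptModule(rank)

def run_ref_script_module_sketch(ref_script_module, t):
    # In the test this would be a @torch.jit.script function taking an
    # RRef; shown here as plain Python for brevity.
    module = ref_script_module.to_here()
    return module.forward() + t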
def test_invalid_devices(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

    with self.assertRaisesRegex(
        RuntimeError,
        r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
        " device type at start of device string",
    ):
        list(
            self._create_remote_module_iter(
                dst_worker_name,
                device="foo",
                modes=[ModuleCreationMode.MODULE_CTOR],
            )
        )

    with self.assertRaisesRegex(RuntimeError, r"CUDA error: invalid device ordinal"):
        list(
            self._create_remote_module_iter(
                dst_worker_name,
                device="cuda:100",
                modes=[ModuleCreationMode.MODULE_CTOR],
            )
        )

    with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
        list(
            self._create_remote_module_iter(
                dst_worker_name,
                modes=[ModuleCreationMode.MODULE_CTOR],
                device="cpu2",
            )
        )

    with self.assertRaisesRegex(
        RuntimeError, r"CPU device index must be -1 or zero, got 2"
    ):
        list(
            self._create_remote_module_iter(
                dst_worker_name,
                device="cpu:2",
                modes=[ModuleCreationMode.MODULE_CTOR],
            )
        )
def test_dist_backward(self):
    if self.rank != 0:
        return

    @torch.jit.script
    def dist_backward_script(context_id: int, loss: torch.Tensor):
        dist_autograd.backward(context_id, [loss])

    FileCheck().check("dist_backward").run(str(dist_backward_script.graph))
    with dist_autograd.context() as context_id:
        t1 = torch.rand(3, 3, requires_grad=True)
        t2 = torch.rand(3, 3, requires_grad=True)
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum()
        dist_backward_script(context_id, loss)
def test_forward_with_kwargs(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
    args = (torch.ones(1), 2)
    kwargs = dict(word="3")
    # Only test Python nn.Module, because script module methods don't support taking kwargs.
    for remote_module in self._create_remote_module_iter(
        dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
    ):
        ret_fut = remote_module.forward_async(*args, **kwargs)
        ret = ret_fut.wait()
        self.assertEqual(ret, tuple(reversed(args + ("3",))))

        ret = remote_module.forward(*args, **kwargs)
        self.assertEqual(ret, tuple(reversed(args + ("3",))))
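# The RemoteModule forward assumed by the tests above roughly reverses its
# inputs. A sketch consistent with the assertions (the class name is an
# assumption, not the suite's actual module):
class ReversingModuleSketch(nn.Module):
    def forward(self, tensor, number, word="default"):
        # Returning (word, number, tensor) equals tuple(reversed(args)) above.
        return word, number, tensor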
def test_load_script_module_with_pickled_rref(self):
    dst_name = worker_name((self.rank + 1) % self.world_size)
    m1 = MyScriptModuleWithRRefs(dst_name)
    m2 = MyScriptModuleWithRRefs(dst_name)

    f = io.BytesIO()

    rpc._enable_jit_rref_pickle()
    torch.jit.save(m1, f)
    rpc._disable_jit_rref_pickle()

    out1 = rpc.rpc_sync(
        dst_name,
        load_script_module_with_pickled_rref,
        args=(f.getvalue(),),
    )
    out2 = m2()
    self.assertEqual(out1, out2)
def test_kwargs_not_passed(self):
    if self.rank != 0:
        return
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)

    @torch.jit.script
    def rpc_async_call_remote_torchscript_in_torchscript_without_kwargs_passed(
        dst_worker_name: str,
    ):
        args = ()
        fut = rpc.rpc_async(dst_worker_name, no_arg, args)
        ret = fut.wait()
        return ret

    ret = rpc_async_call_remote_torchscript_in_torchscript_without_kwargs_passed(
        dst_worker_name
    )
    self.assertEqual(ret, 0)
def test_less_than_needed_args_are_specified(self):
    if self.rank != 0:
        return
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)

    # Note that args matching happens during scripting.
    with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"):

        @torch.jit.script
        def rpc_async_call_remote_torchscript_in_torchscript_with_less_args(
            dst_worker_name: str,  # noqa: E999
        ):
            args = (torch.tensor([1, 1]),)
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
            ret = fut.wait()
            return ret
def test_send_remote_module_with_a_new_attribute_ignored_over_the_wire(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

    # If a new attribute is added to this RemoteModule that will be sent over the wire by RPC,
    # the new field must be added to either _REMOTE_MODULE_PICKLED_ATTRIBUTES
    # or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING to avoid a runtime error.
    for remote_module in self._create_remote_module_iter(
        dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
    ):
        new_attr_name = "new_attr"
        setattr(remote_module, new_attr_name, 1)

        attrs = rpc.rpc_sync(
            dst_worker_name, remote_module_attributes, (remote_module,)
        )
        self.assertNotIn(new_attr_name, attrs)
def test_bad_module(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
    args = (1,)
    kwargs = dict(first_kwarg=2)

    with self.assertRaisesRegex(
        ValueError,
        r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
    ):
        RemoteModule(dst_worker_name, BadModule, args, kwargs)

    # Constructing a second time should fail in the same way.
    with self.assertRaisesRegex(
        ValueError,
        r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
    ):
        RemoteModule(dst_worker_name, BadModule, args, kwargs)
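# A sketch of a "bad" module class for the test above: its constructor does not
# produce an nn.Module instance, which triggers the ValueError. The exact
# definition is an assumption.
class BadModuleSketch:
    def __init__(self, first_arg, first_kwarg=-1):
        pass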
def test_forward_sync_script(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

    scripted_remote_module = next(
        self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
        )
    )

    @torch.jit.script
    def run_forward(scripted_remote_module: MyModuleInterface):
        ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
        return ret

    ret = run_forward(scripted_remote_module)
    self.assertEqual(ret, ("3", 2, torch.ones(1)))
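# A sketch of the TorchScript interface the scripted RemoteModule is assumed to
# satisfy; the signature is inferred from the forward call above and is an
# assumption, not necessarily the suite's actual definition.
@torch.jit.interface
class MyModuleInterfaceSketch:
    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        pass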
def test_call_python_function_remotely_from_script_not_supported(self):
    if self.rank != 0:
        return
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)

    @torch.jit.script
    def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str):
        args = ()
        kwargs = {}
        fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs)
        ret = fut.wait()
        return ret

    with self.assertRaisesRegex(RuntimeError, "attempted to get undefined function"):
        ret = rpc_async_call_remote_py_function_in_torchscript(dst_worker_name)
        self.assertEqual(ret, 0)
def test_train_eval(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

    for remote_module in self._create_remote_module_iter(
        dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
    ):
        remote_module.train()
        ret1 = rpc.rpc_sync(
            dst_worker_name,
            get_remote_training_arg,
            args=(remote_module.get_module_rref(),),
        )
        self.assertEqual(ret1, True)

        remote_module.eval()
        ret2 = rpc.rpc_sync(
            dst_worker_name,
            get_remote_training_arg,
            args=(remote_module.get_module_rref(),),
        )
        self.assertEqual(ret2, False)
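# A plausible sketch (an assumption) of the helper queried above: it runs on the
# module's owner, so the RRef can be dereferenced locally to read the standard
# nn.Module ``training`` flag.
def get_remote_training_arg_sketch(module_rref):
    return module_rref.local_value().training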
def test_rpc_builtin_timeout(self):
    next_rank = (self.rank + 1) % self.world_size
    dst_worker = worker_name(next_rank)
    expected_error = self.get_timeout_error_regex()
    # PYTHON_CALL message types, which correspond to Python UDFs over RPC,
    # get a delay by default (see faulty_rpc_agent_test_fixture).
    with self.assertRaisesRegex(RuntimeError, expected_error):
        rpc.rpc_sync(
            dst_worker,
            torch.add,
            args=(torch.tensor(1), torch.tensor(1)),
            timeout=1,
        )

    fut = rpc.rpc_async(
        dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1
    )
    with self.assertRaisesRegex(RuntimeError, expected_error):
        fut.wait()

    # Ensure that the currently set default timeout is large enough such
    # that RPCs with delays still complete.
    fut = rpc.rpc_async(
        dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
    )
    fut.wait()

    # Ensure we time out if we set a new default and don't override it.
    rpc._set_rpc_timeout(0.001)
    fut = rpc.rpc_async(
        dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
    )
    with self.assertRaisesRegex(RuntimeError, expected_error):
        fut.wait()

    # Ensure the call runs to completion if we specify a timeout of 0.
    fut = rpc.rpc_async(
        dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0
    )
    fut.wait()

    # Reset for clean shutdown.
    rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

    # If a new attribute is added to this RemoteModule after initialization
    # and it is sent over the wire by RPC, the new field will not be pickled,
    # because it is not specified in _REMOTE_MODULE_PICKLED_ATTRIBUTES.
    # Note that adding a new attribute outside of the constructor should rarely happen.
    # If a new attribute is added to the RemoteModule constructor,
    # a sanity check forces developers to add it to either
    # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
    for remote_module in self._create_remote_module_iter(
        dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
    ):
        new_attr_name = "new_attr"
        setattr(remote_module, new_attr_name, 1)

        attrs = rpc.rpc_sync(
            dst_worker_name, remote_module_attributes, (remote_module,)
        )
        self.assertNotIn(new_attr_name, attrs)
def test_torchscript_functions_not_supported(self):
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)

    # rpc_sync still accepts a script class and runs it in
    # the same code path as a Python call.
    ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,))

    # rpc_sync does not accept script modules or script module methods.
    with self.assertRaisesRegex(RuntimeError, "ScriptModules cannot be deepcopied"):
        ret = rpc.rpc_sync(dst_worker_name, MyScriptModule, args=(self.rank,))

    # Python 3.5 and Python 3.6 throw different error messages; the only
    # common word that can be grepped for is "pickle".
    with self.assertRaisesRegex(TypeError, "pickle"):
        ret = rpc.rpc_async(
            dst_worker_name, MyScriptModule(self.rank).forward, args=()
        )
def test_rpc_async_jit_profiled(self):
    # Tests that rpc_async calls made from within a TorchScript function are
    # profiled.
    if self.rank == 0:
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker_name = worker_name(dst_rank)
        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {}
        with torch.autograd.profiler.profile() as prof:
            rpc_async_call_remote_torchscript_in_torchscript(
                dst_worker_name, args, kwargs
            )

        # Ensure rpc_async call is profiled
        function_events = prof.function_events
        qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs)
        rpc_async_jit_event = [
            event
            for event in function_events
            if qual_name in event.name and event.node_id == self.rank
        ]
        self.assertEqual(len(rpc_async_jit_event), 1)
        rpc_async_jit_event = rpc_async_jit_event[0]
        profiled_name = f"rpc_async_jit#({qual_name})#({worker_name(self.rank)})->({dst_worker_name})"
        self.assertEqual(profiled_name, rpc_async_jit_event.name)
        remote_events = [event for event in function_events if event.is_remote]
        # All remote events should have taken place on dst_rank
        remote_event_node_ids = {
            remote_event.node_id for remote_event in remote_events
        }
        self.assertEqual(remote_event_node_ids, {dst_rank})
        # rpc_async_call_remote_torchscript_in_torchscript invokes add operator
        # so we should see this as a remote event.
        remote_add = [
            remote_event
            for remote_event in remote_events
            if "aten::add" in remote_event.name
        ][0]
        remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add"
        self.assertEqual(remote_add.name, remote_add_profiled_name)
def test_input_moved_to_cuda_device_script(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

    scripted_remote_module = next(
        self._create_remote_module_iter(
            "{}/cuda:0".format(dst_worker_name),
            modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE],
        )
    )

    @torch.jit.script
    def run_forward(scripted_remote_module: MyModuleInterface):
        ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
        return ret

    ret = run_forward(scripted_remote_module)

    self.assertEqual(ret, ("3", 2, torch.ones(1)))
    # TODO: Once the RPC backend can support directly sending GPU tensors,
    # the expected device type should be "cuda:0".
    self.assertEqual(ret[2].device.type, "cpu")
def test_future_python_annotation(self):
    if self.rank != 0:
        return

    dst_worker_name = worker_name((self.rank + 1) % self.world_size)
    input_0 = torch.ones(2, 2)
    input_1 = 1
    expected_res = torch.add(input_0, input_1)

    @torch.jit.ignore
    def python_return_future() -> Future[Tensor]:
        fut = rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {})
        return fut

    @torch.jit.script
    def script_use_future() -> Tensor:
        fut = python_return_future()
        return fut.wait()

    res = script_use_future()
    self.assertEqual(res, expected_res)
def test_send_remote_module_with_a_new_attribute_over_the_wire(self):
    if self.rank != 0:
        return
    dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

    # If a new attribute is added to this RemoteModule that will be sent over the wire by RPC,
    # the new field must be added to either _REMOTE_MODULE_PICKLED_ATTRIBUTES
    # or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING to avoid a runtime error.
    for remote_module in self._create_remote_module_iter(
        dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
    ):
        new_attr_name = "new_attr"
        setattr(remote_module, new_attr_name, 1)

        with self.assertRaisesRegex(
            RuntimeError,
            "Attribute ``{}`` of RemoteModule must be either in "
            "``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or "
            "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``".format(new_attr_name),
        ):
            rpc.rpc_sync(dst_worker_name, remote_module_attributes, (remote_module,))
def test_call_script_function_that_raises_remotely_from_script(self):
    if self.rank != 0:
        return
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)

    # Note that TorchScript always translates (emits) a Python `raise` statement
    # as the exception message string "Exception",
    # no matter what exception type and exception message are in the statement.
    @torch.jit.script
    def rpc_async_call_remote_raising_torchscript_in_torchscript(
        dst_worker_name: str,
    ):
        args = ()
        kwargs = {}
        fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs)
        ret = fut.wait()
        return ret

    with self.assertRaisesRegex(RuntimeError, "Exception"):
        ret = rpc_async_call_remote_raising_torchscript_in_torchscript(
            dst_worker_name
        )
        self.assertEqual(ret, 0)
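# A sketch of the raising TorchScript function called remotely above. The exact
# error type and message are assumptions; as the comment above notes,
# TorchScript surfaces them as "Exception" regardless.
@torch.jit.script
def raise_script_sketch():
    raise RuntimeError("Expected error")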
def test_future_passed_between_python_and_jit(self):
    dst_rank = (self.rank + 1) % self.world_size
    inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
    ret_fut = rpc.rpc_async(worker_name(dst_rank), two_args_two_kwargs, args=inputs)
    expected_res = torch.tensor([10, 10])

    @torch.jit.script
    def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
        return fut.wait()

    self.assertEqual(future_wait_in_script(ret_fut), expected_res)

    @torch.jit.script
    def future_return_to_python(
        dst_rank: int, inputs: Tuple[Tensor, Tensor]
    ) -> Future[Tensor]:
        return rpc.rpc_async(
            "worker{}".format(dst_rank), two_args_two_kwargs, inputs
        )

    fut_res = future_return_to_python(dst_rank, inputs)
    self.assertEqual(fut_res.wait(), expected_res)
def test_unexpected_kwarg_is_specified(self):
    if self.rank != 0:
        return
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)

    # Note that kwargs matching happens during execution.
    @torch.jit.script
    def script_rpc_async_call_with_unexpected_kwarg(
        dst_worker_name: str,  # noqa: E999
    ):
        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {"third_kwarg": torch.tensor([1, 1])}
        fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
        ret = fut.wait()
        return ret

    with self.assertRaisesRegex(
        RuntimeError, "Unknown keyword argument 'third_kwarg'"
    ):
        ret = script_rpc_async_call_with_unexpected_kwarg(dst_worker_name)
        self.assertEqual(ret, 0)
def test_add_done_callback(self):
    callback_called = None

    def callback(fut):
        nonlocal callback_called
        callback_called = fut.wait() * 2

    future = rpc.rpc_async(
        worker_name((self.rank + 1) % self.world_size),
        script_fork_wait_udf,
        args=(torch.ones(2),),
    )

    future.add_done_callback(callback)
    future_then = future.then(lambda _: True)

    self.assertEqual(future.wait(), torch.ones(2) * 2)

    # There is no guarantee that the add_done_callback fn will execute before
    # the test finishes, so add a `then` callback that runs afterwards and wait
    # on it to guarantee the first callback has run.
    future_then.wait()
    self.assertEqual(callback_called, torch.ones(2) * 4)
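# A plausible sketch (an assumption) of the UDF used above: fork a scripted
# add-one op and wait on it, so torch.ones(2) comes back as torch.ones(2) * 2.
@torch.jit.script
def script_add_ones_sketch(x: Tensor) -> Tensor:
    return torch.add(x, 1)

def script_fork_wait_udf_sketch(tensor):
    fut = torch.jit.fork(script_add_ones_sketch, tensor)
    return torch.jit.wait(fut)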
def test_valid_device(self):
    if self.rank != 0:
        return
    dst_rank = (self.rank + 1) % self.world_size
    dst_worker_name = dist_utils.worker_name(dst_rank)

    for remote_module in self._create_remote_module_iter(
        "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR]
    ):
        device = rpc.rpc_sync(
            dst_worker_name, remote_device, (remote_module.module_rref,)
        )
        self.assertEqual(device.type, "cuda")
        self.assertEqual(device.index, 0)

    # Test that specifying a rank works as well.
    for remote_module in self._create_remote_module_iter(
        "rank:{}/cuda:0".format(dst_rank), modes=[ModuleCreationMode.MODULE_CTOR]
    ):
        device = rpc.rpc_sync(
            dst_worker_name, remote_device, (remote_module.module_rref,)
        )
        self.assertEqual(device.type, "cuda")
        self.assertEqual(device.index, 0)
def test_create_script_module_on_remote(self):
    dst_name = worker_name((self.rank + 1) % self.world_size)
    # Construct on the remote end with rpc_sync.
    created_script_module = rpc.rpc_sync(dst_name, MyScriptModule, args=(self.rank,))
    # Forward should output a ones tensor of self.rank.
    self.assertTrue(isinstance(created_script_module, torch.jit.ScriptModule))
    rank_ones_tensor = created_script_module()
    self.assertEqual(torch.ones(self.rank), rank_ones_tensor)

    # Construct a ScriptModule with rpc.remote.
    remote_script_module = rpc.remote(dst_name, MyScriptModule, args=(self.rank,))
    # Verify it is an instance of ScriptModule on the remote end.
    remote_end_is_script = rpc.rpc_sync(
        remote_script_module.owner(),
        rref_isinstance,
        args=(remote_script_module, torch.jit.ScriptModule),
    )
    self.assertTrue(remote_end_is_script)
    # Run a forward pass remotely.
    remote_forward_output = remote_script_module.rpc_sync().forward()
    self.assertEqual(remote_forward_output, torch.ones(self.rank))
    # Run a function defined on the ScriptModule remotely.
    remote_func_output = remote_script_module.rpc_sync().custom_func()
    self.assertEqual(remote_func_output, torch.ones(self.rank))
    # Ensure we can transfer the ScriptModule RRef to this rank and run a
    # forward pass.
    local_script_module = remote_script_module.to_here()
    self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule))
    rank_ones_tensor = local_script_module()
    self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
    local_script_func_output = local_script_module.custom_func()
    self.assertEqual(local_script_func_output, torch.ones(self.rank))
def test_torchscript_functions_not_supported(self):
    dst_worker_name = worker_name((self.rank + 1) % self.world_size)

    my_local_script_module = MyScriptModule(self.rank)

    # It is not thread safe to instantiate MyScriptModule in multiple threads,
    # so wait for the local MyScriptModule instantiation to finish;
    # otherwise, the calls below could instantiate MyScriptModule in parallel
    # with the server thread.
    initialize_pg(self.file_init_method, self.rank, self.world_size)
    dist.barrier()

    # rpc_sync still accepts a script class and runs it in
    # the same code path as a Python call.
    ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,))

    # rpc_sync does not accept script module methods.
    # Python 3.5 and Python 3.6 throw different error messages; the only
    # common word that can be grepped for is "pickle".
    with self.assertRaisesRegex(TypeError, "pickle"):
        ret = rpc.rpc_async(
            dst_worker_name, my_local_script_module.forward, args=()
        )