Example #1
    def test_some_kwargs_are_populated_by_defaults(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        kwargs = {"first_kwarg": torch.tensor([2, 2])}

        for script_op in [
                script_rpc_async_call, script_rpc_sync_call,
                script_rpc_remote_call
        ]:
            ret = script_op(dst_worker_name, args, kwargs)
            self.assertEqual(ret, torch.tensor([9, 9]))
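The wrappers `script_rpc_async_call`, `script_rpc_sync_call`, and `script_rpc_remote_call`, along with the remote target `two_args_two_kwargs`, are defined elsewhere in the test file. A minimal sketch consistent with the expected [9, 9] result (the default kwarg values and the wrapper body are assumptions):

from typing import Dict, Tuple

import torch
import torch.distributed.rpc as rpc


@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    # [1, 1] + [2, 2] + [2, 2] + [4, 4] == [9, 9] for the call above.
    return first_arg + second_arg + first_kwarg + second_kwarg


@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str,
    args: Tuple[torch.Tensor, torch.Tensor],
    kwargs: Dict[str, torch.Tensor],
):
    # Issue the RPC from inside TorchScript and block on the returned future.
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    ret = fut.wait()
    return ret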
Example #2
    def test_args_kwargs_are_neither_passed(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def script_rpc_async_call_without_args_kwargs_passed(
            dst_worker_name: str, ):
            fut = rpc.rpc_async(dst_worker_name, no_arg)
            ret = fut.wait()
            return ret

        ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name)
        self.assertEqual(ret, 0)
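`no_arg` is assumed to be a trivial scripted function with no parameters that returns 0, e.g.:

@torch.jit.script
def no_arg():
    return 0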
Example #3
    def __init__(self, world_size):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.policy = Policy()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.eps = np.finfo(np.float32).eps.item()
        self.running_reward = 0
        self.reward_threshold = DummyEnv().reward_threshold
        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(worker_name(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []
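This constructor follows the RPC reinforcement-learning pattern: the agent keeps one RRef per remote Observer. A hedged sketch of how the agent could later fan work out to those RRefs (the method name `run_episode` and its signature on Observer are assumptions, not part of the original snippet):

    def run_episode(self, n_steps: int):
        # Kick off one episode on every observer in parallel. ob_rref.rpc_async()
        # is the RRef helper proxy that issues an async RPC to the RRef's owner.
        futs = [
            ob_rref.rpc_async().run_episode(self.agent_rref, n_steps)
            for ob_rref in self.ob_rrefs
        ]
        # Wait for all observers to finish before updating the policy.
        for fut in futs:
            fut.wait()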
Example #4
    def test_send_remote_module_over_the_wire(self):
        if self.rank != 0:
            return
        dst_worker1_name = dist_utils.worker_name(
            (self.rank + 1) % self.world_size)
        dst_worker2_name = dist_utils.worker_name(
            (self.rank + 2) % self.world_size)

        # Unpickled attributes include both the inherent attributes of RemoteModule
        # (not inherited from the superclass) and the two installed methods.
        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
        expected_unpickled_attrs.append("forward_async")
        expected_unpickled_attrs.append("forward")

        # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
        for remote_module in self._create_remote_module_iter(
                dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]):
            # Test querying some simple attributes from worker2.
            attrs = rpc.rpc_sync(dst_worker2_name, remote_module_attributes,
                                 (remote_module, ))
            self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs)
            self.assertEqual(attrs["on"], "worker1")
            self.assertEqual(attrs["device"], "cpu")
            self.assertFalse(attrs["is_device_map_set"])
            self.assertFalse(attrs["is_scriptable"])

            # Test that the methods installed on worker1 can be invoked by worker2 over the RPC layer.
            # NOTE: In practice a remote module should be stored directly on the worker that runs ``forward`` or ``forward_async``,
            # rather than having another worker initiate the forward over the RPC layer.
            args = (torch.ones(1), 2, "3")
            ret1 = rpc.rpc_sync(dst_worker2_name, remote_forward,
                                (remote_module, args))
            self.assertEqual(ret1, tuple(reversed(args)))
            ret2 = rpc.rpc_sync(dst_worker2_name, remote_forward_async,
                                (remote_module, args))
            self.assertEqual(ret2, tuple(reversed(args)))
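The RPC targets `remote_module_attributes`, `remote_forward`, and `remote_forward_async` run on worker2. A minimal sketch consistent with the assertions above (the names come from the test, the bodies are assumptions):

def remote_module_attributes(remote_module):
    # Report the attribute dict of the unpickled RemoteModule on this worker.
    return remote_module.__dict__


def remote_forward(remote_module, args):
    # Invoke the synchronous forward installed on the RemoteModule.
    return remote_module.forward(*args)


def remote_forward_async(remote_module, args):
    # Invoke the asynchronous forward and block on its future.
    return remote_module.forward_async(*args).wait()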
Example #5
    def test_remote_script_module(self):
        # TODO: needs more investigation.
        # There is an RRef leak when shutting down; the suspicion is that the RRef
        # passed as an arg crosses the pybind boundary and is not garbage
        # collected by Python when shutdown() is called.
        import torch.distributed.rpc.api as api

        api._ignore_rref_leak = True

        local_ret = torch.ones(self.rank) + torch.ones(self.rank)

        n = self.rank + 1
        dst_rank = n % self.world_size
        remote_ref = rpc.remote(worker_name(dst_rank),
                                construct_my_script_module,
                                args=(self.rank, ))

        # pass rref arg to owner
        ret = rpc.rpc_sync(
            worker_name(dst_rank),
            run_ref_script_module,
            args=(remote_ref, torch.ones(self.rank)),
        )
        self.assertEqual(ret, local_ret)
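`construct_my_script_module` and `run_ref_script_module` are helpers from the test file; a sketch of what they might do, consistent with local_ret being torch.ones(rank) + torch.ones(rank) (sketched here as plain Python RPC targets, which is an assumption):

def construct_my_script_module(rank):
    # Build the script module on the callee; its forward() is assumed to
    # return torch.ones(rank).
    return MyScriptModule(rank)


def run_ref_script_module(ref_script_module, t):
    # Runs on the RRef's owner: fetch the module locally and add the extra tensor.
    module = ref_script_module.to_here()
    return module.forward() + t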
Example #6
    def test_invalid_devices(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
            " device type at start of device string",
        ):
            list(
                self._create_remote_module_iter(
                    dst_worker_name,
                    device="foo",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            )

        with self.assertRaisesRegex(
            RuntimeError, r"CUDA error: invalid device ordinal"
        ):
            list(
                self._create_remote_module_iter(
                    dst_worker_name,
                    device="cuda:100",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            )

        with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
            list(
                self._create_remote_module_iter(
                    dst_worker_name,
                    modes=[ModuleCreationMode.MODULE_CTOR],
                    device="cpu2",
                )
            )

        with self.assertRaisesRegex(
            RuntimeError, r"CPU device index must be -1 or zero, got 2"
        ):
            list(
                self._create_remote_module_iter(
                    dst_worker_name,
                    device="cpu:2",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            )
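`_create_remote_module_iter` is a test helper whose exact signature differs between versions of this test suite. A rough sketch of the keyword-device form used above (the wrapped module class and constructor arguments are assumptions, and the scripted-interface creation mode is omitted):

    def _create_remote_module_iter(self, dst_worker_name, device="cpu", modes=None):
        # Yield one RemoteModule per requested creation mode so each test can
        # iterate over all supported construction paths.
        args = (1,)
        kwargs = dict(first_kwarg=2)
        if modes is None or ModuleCreationMode.MODULE_CTOR in modes:
            remote_device = "{}/{}".format(dst_worker_name, device)
            yield RemoteModule(remote_device, MyModule, args, kwargs)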
Example #7
    def test_dist_backward(self):
        if self.rank != 0:
            return

        @torch.jit.script
        def dist_backward_script(context_id: int, loss: torch.Tensor):
            dist_autograd.backward(context_id, [loss])

        FileCheck().check("dist_backward").run(str(dist_backward_script.graph))
        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(3, 3, requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss = rpc.rpc_sync(dst_worker_name, torch.add,
                                args=(t1, t2)).sum()
            dist_backward_script(context_id, loss)
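As a side note, still inside the same dist_autograd context, the gradients accumulated by that backward pass could be read back with dist_autograd.get_gradients. This follow-up is not part of the original test:

            # Gradients live in the autograd context, keyed by tensor, not in .grad.
            grads = dist_autograd.get_gradients(context_id)
            # loss = (t1 + t2).sum(), so both t1 and t2 receive a gradient of ones.
            assert torch.equal(grads[t1], torch.ones(3, 3))
            assert torch.equal(grads[t2], torch.ones(3, 3))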
Example #8
    def test_forward_with_kwargs(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (torch.ones(1), 2)
        kwargs = dict(word="3")
        # Only test Python nn.Module, because script module methods don't support taking kwargs.
        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            ret_fut = remote_module.forward_async(*args, **kwargs)
            ret = ret_fut.wait()
            self.assertEqual(ret, tuple(reversed(args + ("3",))))

            ret = remote_module.forward(*args, **kwargs)
            self.assertEqual(ret, tuple(reversed(args + ("3",))))
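The kwarg name `word` implies a forward signature along these lines; the return order is inferred from the reversed-tuple assertions, and the body is otherwise an assumption:

import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, first_arg, first_kwarg=-1):
        # Constructor arguments are accepted but unused in this sketch.
        super().__init__()

    def forward(self, tensor, number, word="default"):
        # Echo the inputs back in reverse order, matching the
        # tuple(reversed(...)) assertions in the tests above.
        return word, number, tensor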
Example #9
    def test_load_script_module_with_pickled_rref(self):
        dst_name = worker_name((self.rank + 1) % self.world_size)
        m1 = MyScriptModuleWithRRefs(dst_name)
        m2 = MyScriptModuleWithRRefs(dst_name)

        f = io.BytesIO()

        rpc._enable_jit_rref_pickle()
        torch.jit.save(m1, f)
        rpc._disable_jit_rref_pickle()

        out1 = rpc.rpc_sync(dst_name,
                            load_script_module_with_pickled_rref,
                            args=(f.getvalue(), ))
        out2 = m2()
        self.assertEqual(out1, out2)
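`load_script_module_with_pickled_rref` runs on the destination worker; a sketch consistent with the bytes produced by torch.jit.save (the body is an assumption):

import io

import torch


def load_script_module_with_pickled_rref(pickled_script_module):
    # Reconstruct the saved ScriptModule, including its pickled RRefs
    # (enabled via rpc._enable_jit_rref_pickle on the caller), and run it.
    f = io.BytesIO(pickled_script_module)
    m = torch.jit.load(f)
    return m()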
Example #10
    def test_kwargs_not_passed(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def rpc_async_call_remote_torchscript_in_torchscript_without_kwargs_passed(
                dst_worker_name: str):
            args = ()
            fut = rpc.rpc_async(dst_worker_name, no_arg, args)
            ret = fut.wait()
            return ret

        ret = rpc_async_call_remote_torchscript_in_torchscript_without_kwargs_passed(
            dst_worker_name)
        self.assertEqual(ret, 0)
Example #11
    def test_less_than_needed_args_are_specified(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Note that argument matching happens during scripting.
        with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"):

            @torch.jit.script
            def rpc_async_call_remote_torchscript_in_torchscript_with_less_args(
                dst_worker_name: str,  # noqa: E999
            ):
                args = (torch.tensor([1, 1]),)
                kwargs = {}
                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
                ret = fut.wait()
                return ret
Example #12
    def test_send_remote_module_with_a_new_attribute_ignored_over_the_wire(
            self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name(
            (self.rank + 1) % self.world_size)

        # If a new attribute is added to this RemoteModule and it will be sent over the wire by RPC,
        # the new field must be added to either _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
        # to avoid a runtime error.
        for remote_module in self._create_remote_module_iter(
                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]):
            new_attr_name = "new_attr"
            setattr(remote_module, new_attr_name, 1)

            attrs = rpc.rpc_sync(dst_worker_name, remote_module_attributes,
                                 (remote_module, ))
            self.assertNotIn(new_attr_name, attrs)
Example #13
    def test_bad_module(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (1,)
        kwargs = dict(first_kwarg=2)

        with self.assertRaisesRegex(
            ValueError,
            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
        ):
            RemoteModule(dst_worker_name, BadModule, args, kwargs)

        with self.assertRaisesRegex(
            ValueError,
            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
        ):
            RemoteModule(dst_worker_name, BadModule, args, kwargs)
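`BadModule` only needs to be something whose construction does not yield an nn.Module; a minimal sketch matching the constructor arguments used above (the body is an assumption):

class BadModule:
    # Deliberately NOT an nn.Module subclass, so RemoteModule's type check
    # on the remote end raises the ValueError asserted above.
    def __init__(self, first_arg, first_kwarg=-1):
        pass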
Example #14
    def test_forward_sync_script(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        scripted_remote_module = next(
            self._create_remote_module_iter(
                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
            )
        )

        @torch.jit.script
        def run_forward(scripted_remote_module: MyModuleInterface):
            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
            return ret

        ret = run_forward(scripted_remote_module)

        self.assertEqual(ret, ("3", 2, torch.ones(1)))
Example #15
    def test_call_python_function_remotely_from_script_not_supported(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        @torch.jit.script
        def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str):
            args = ()
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(
            RuntimeError, "attempted to get undefined function"
        ):
            ret = rpc_async_call_remote_py_function_in_torchscript(dst_worker_name)
            self.assertEqual(ret, 0)
Example #16
    def test_train_eval(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name(
            (self.rank + 1) % self.world_size)

        for remote_module in self._create_remote_module_iter(
                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]):
            remote_module.train()
            ret1 = rpc.rpc_sync(dst_worker_name,
                                get_remote_training_arg,
                                args=(remote_module.get_module_rref(), ))
            self.assertEqual(ret1, True)

            remote_module.eval()
            ret2 = rpc.rpc_sync(dst_worker_name,
                                get_remote_training_arg,
                                args=(remote_module.get_module_rref(), ))
            self.assertEqual(ret2, False)
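`get_remote_training_arg` executes on the remote worker and reports the training flag of the module behind the RRef; a minimal sketch (the body is an assumption):

def get_remote_training_arg(module_rref):
    # module_rref is owned by this worker, so local_value() returns the actual
    # nn.Module; .training reflects the train()/eval() calls made above.
    return module_rref.local_value().training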
Example #17
    def test_rpc_builtin_timeout(self):
        next_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(next_rank)
        expected_error = self.get_timeout_error_regex()
        # PYTHON_CALL message types, which correspond to Python UDFs over RPC,
        # get a delay by default (see faulty_rpc_agent_test_fixture).
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc.rpc_sync(
                dst_worker,
                torch.add,
                args=(torch.tensor(1), torch.tensor(1)),
                timeout=1,
            )

        fut = rpc.rpc_async(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1
        )
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure that the currently set default timeout is large enough such
        # that RPCs with delays still complete.
        fut = rpc.rpc_async(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        fut.wait()

        # Ensure the RPC times out if we set a new default and don't override it.
        rpc._set_rpc_timeout(0.001)
        fut = rpc.rpc_async(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure the RPC runs to completion if we specify a timeout of 0.
        fut = rpc.rpc_async(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0
        )
        fut.wait()
        # Reset for clean shutdown
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
Example #18
    def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        # If a new attribute is added to this RemoteModule after initialization
        # and the module is then sent over the wire by RPC,
        # the new field will not be pickled, because it is not listed in _REMOTE_MODULE_PICKLED_ATTRIBUTES.
        # Note that adding a new attribute outside the constructor should rarely happen.
        # If a new attribute is added in the RemoteModule constructor,
        # a sanity check forces developers to add it to either
        # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            new_attr_name = "new_attr"
            setattr(remote_module, new_attr_name, 1)

            attrs = rpc.rpc_sync(dst_worker_name, remote_module_attributes, (remote_module,))
            self.assertNotIn(new_attr_name, attrs)
Example #19
    def test_torchscript_functions_not_supported(self):
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # rpc_sync still accepts a script class and runs it in
        # the same code path as a Python call.
        ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank, ))

        # rpc_sync does not accept a script module or a script module method.
        with self.assertRaisesRegex(RuntimeError,
                                    "ScriptModules cannot be deepcopied"):
            ret = rpc.rpc_sync(dst_worker_name,
                               MyScriptModule,
                               args=(self.rank, ))

        # Python 3.5 and Python 3.6 throw different error messages; the only
        # common word that can be grepped for is "pickle".
        with self.assertRaisesRegex(TypeError, "pickle"):
            ret = rpc.rpc_async(dst_worker_name,
                                MyScriptModule(self.rank).forward,
                                args=())
Example #20
    def test_rpc_async_jit_profiled(self):
        # Tests that rpc_async calls made from within a TorchScript function are
        # profiled.
        if self.rank == 0:
            dst_rank = (self.rank + 1) % self.world_size
            dst_worker_name = worker_name(dst_rank)
            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
            kwargs = {}
            with torch.autograd.profiler.profile() as prof:
                rpc_async_call_remote_torchscript_in_torchscript(
                    dst_worker_name, args, kwargs)

            # Ensure rpc_async call is profiled
            function_events = prof.function_events
            qual_name = torch._jit_internal._qualified_name(
                two_args_two_kwargs)
            rpc_async_jit_event = [
                event for event in function_events
                if qual_name in event.name and event.node_id == self.rank
            ]
            self.assertEqual(len(rpc_async_jit_event), 1)
            rpc_async_jit_event = rpc_async_jit_event[0]
            profiled_name = f"rpc_async_jit#({qual_name})#({worker_name(self.rank)})->({dst_worker_name})"
            self.assertEqual(profiled_name, rpc_async_jit_event.name)
            remote_events = [
                event for event in function_events if event.is_remote
            ]
            # All remote events should have taken place on dst_rank
            remote_event_node_ids = {
                remote_event.node_id
                for remote_event in remote_events
            }
            self.assertEqual(remote_event_node_ids, {dst_rank})
            # rpc_async_call_remote_torchscript_in_torchscript invokes the add operator,
            # so we should see it as a remote event.
            remote_add = [
                remote_event for remote_event in remote_events
                if "aten::add" in remote_event.name
            ][0]
            remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add"
            self.assertEqual(remote_add.name, remote_add_profiled_name)
Example #21
    def test_input_moved_to_cuda_device_script(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        scripted_remote_module = next(
            self._create_remote_module_iter(
                "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
            )
        )

        @torch.jit.script
        def run_forward(scripted_remote_module: MyModuleInterface):
            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
            return ret

        ret = run_forward(scripted_remote_module)

        self.assertEqual(ret, ("3", 2, torch.ones(1)))
        # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
        self.assertEqual(ret[2].device.type, "cpu")
Example #22
    def test_future_python_annotation(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
        input_0 = torch.ones(2, 2)
        input_1 = 1
        expected_res = torch.add(input_0, input_1)

        @torch.jit.ignore
        def python_return_future() -> Future[Tensor]:
            fut = rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {})
            return fut

        @torch.jit.script
        def script_use_future() -> Tensor:
            fut = python_return_future()
            return fut.wait()

        res = script_use_future()
        self.assertEqual(res, expected_res)
Example #23
    def test_send_remote_module_with_a_new_attribute_over_the_wire(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name(
            (self.rank + 1) % self.world_size)

        # If a new attribute is added to this RemoteModule and it will be sent over the wire by RPC,
        # the new field must be added to either _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
        # to avoid a runtime error.
        for remote_module in self._create_remote_module_iter(
                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]):
            new_attr_name = "new_attr"
            setattr(remote_module, new_attr_name, 1)
            with self.assertRaisesRegex(
                    RuntimeError,
                    "Attribute ``{}`` of RemoteModule must be either in "
                    "``_REMOTE_MODULE_PICKLED_ATTRIBUTES``  or ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``"
                    .format(new_attr_name),
            ):
                rpc.rpc_sync(dst_worker_name, remote_module_attributes,
                             (remote_module, ))
Example #24
    def test_call_script_function_that_raises_remotely_from_script(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Note that TorchScript always translates (emits) a Python `raise` statement
        # as the exception message string "Exception",
        # no matter what exception type and exception message are in the statement.
        @torch.jit.script
        def rpc_async_call_remote_raising_torchscript_in_torchscript(
                dst_worker_name: str):
            args = ()
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(RuntimeError, "Exception"):
            ret = rpc_async_call_remote_raising_torchscript_in_torchscript(
                dst_worker_name)
            self.assertEqual(ret, 0)
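`raise_script` is assumed to be a scripted function that simply raises; the specific message is irrelevant, since TorchScript surfaces it as described in the comment above:

@torch.jit.script
def raise_script():
    raise RuntimeError("Expected error")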
Example #25
    def test_future_passed_between_python_and_jit(self):
        dst_rank = (self.rank + 1) % self.world_size
        inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        ret_fut = rpc.rpc_async(worker_name(dst_rank), two_args_two_kwargs, args=inputs)
        expected_res = torch.tensor([10, 10])

        @torch.jit.script
        def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
            return fut.wait()

        self.assertEqual(future_wait_in_script(ret_fut), expected_res)

        @torch.jit.script
        def future_return_to_python(
            dst_rank: int, inputs: Tuple[Tensor, Tensor]
        ) -> Future[Tensor]:
            return rpc.rpc_async(
                "worker{}".format(dst_rank), two_args_two_kwargs, inputs
            )

        fut_res = future_return_to_python(dst_rank, inputs)
        self.assertEqual(fut_res.wait(), expected_res)
Example #26
    def test_unexepected_kwarg_is_specified(self):
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        # Note that kwargs matching happens during execution.
        @torch.jit.script
        def script_rpc_async_call_with_unexpected_kwarg(
                dst_worker_name: str,  # noqa: E999
        ):
            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
            kwargs = {"third_kwarg": torch.tensor([1, 1])}
            fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args,
                                kwargs)
            ret = fut.wait()
            return ret

        with self.assertRaisesRegex(RuntimeError,
                                    "Unknown keyword argument 'third_kwarg'"):
            ret = script_rpc_async_call_with_unexpected_kwarg(dst_worker_name)
            self.assertEqual(ret, 0)
Example #27
    def test_add_done_callback(self):
        callback_called = None

        def callback(fut):
            nonlocal callback_called
            callback_called = fut.wait() * 2

        future = rpc.rpc_async(
            worker_name((self.rank + 1) % self.world_size),
            script_fork_wait_udf,
            args=(torch.ones(2), ),
        )

        future.add_done_callback(callback)
        future_then = future.then(lambda _: True)

        self.assertEqual(future.wait(), torch.ones(2) * 2)

        # There is no guarantee that the add_done_callback fn will execute before the test finishes.
        # Add a 'then' callback that runs afterwards and wait on it to guarantee the first callback has run.
        future_then.wait()
        self.assertEqual(callback_called, torch.ones(2) * 4)
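`script_fork_wait_udf` is assumed to double its input through a forked TorchScript task, which is consistent with the torch.ones(2) * 2 assertion above (the exact body is an assumption):

@torch.jit.script
def two_args_add(x, y):
    return x + y


@torch.jit.script
def script_fork_wait_udf(tensor):
    # Fork an async TorchScript task and wait on it before returning.
    fut = torch.jit.fork(two_args_add, tensor, tensor)
    return torch.jit.wait(fut)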
Example #28
    def test_valid_device(self):
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker_name = dist_utils.worker_name(dst_rank)

        for remote_module in self._create_remote_module_iter(
                "{}/cuda:0".format(dst_worker_name),
                modes=[ModuleCreationMode.MODULE_CTOR]):
            device = rpc.rpc_sync(dst_worker_name, remote_device,
                                  (remote_module.module_rref, ))
            self.assertEqual(device.type, "cuda")
            self.assertEqual(device.index, 0)

        # Test that specifying the rank works as well.
        for remote_module in self._create_remote_module_iter(
                "rank:{}/cuda:0".format(dst_rank),
                modes=[ModuleCreationMode.MODULE_CTOR]):
            device = rpc.rpc_sync(dst_worker_name, remote_device,
                                  (remote_module.module_rref, ))
            self.assertEqual(device.type, "cuda")
            self.assertEqual(device.index, 0)
Example #29
    def test_create_script_module_on_remote(self):
        dst_name = worker_name((self.rank + 1) % self.world_size)
        # Construct on remote end with rpc_sync
        created_script_module = rpc.rpc_sync(dst_name,
                                             MyScriptModule,
                                             args=(self.rank, ))
        # Forward should output a ones tensor of self.rank.
        self.assertTrue(
            isinstance(created_script_module, torch.jit.ScriptModule))
        rank_ones_tensor = created_script_module()
        self.assertEqual(torch.ones(self.rank), rank_ones_tensor)

        # Construct ScriptModule with rpc.remote.
        remote_script_module = rpc.remote(dst_name,
                                          MyScriptModule,
                                          args=(self.rank, ))
        # Verify it is an instance of ScriptModule on remote end.
        remote_end_is_script = rpc.rpc_sync(
            remote_script_module.owner(),
            rref_isinstance,
            args=(remote_script_module, torch.jit.ScriptModule),
        )
        self.assertTrue(remote_end_is_script)
        # Run forward pass remotely.
        remote_forward_output = remote_script_module.rpc_sync().forward()
        self.assertEqual(remote_forward_output, torch.ones(self.rank))
        # Run function defined on ScriptModule remotely.
        remote_func_output = remote_script_module.rpc_sync().custom_func()
        self.assertEqual(remote_func_output, torch.ones(self.rank))
        # Ensure we can transfer ScriptModule RRef to this rank and run
        # forward pass.
        local_script_module = remote_script_module.to_here()
        self.assertTrue(isinstance(local_script_module,
                                   torch.jit.ScriptModule))
        rank_ones_tensor = local_script_module()
        self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
        local_script_func_output = local_script_module.custom_func()
        self.assertEqual(local_script_func_output, torch.ones(self.rank))
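`MyScriptModule` is assumed to be a ScriptModule whose forward and custom_func both return torch.ones(rank), consistent with the assertions above:

import torch


class MyScriptModule(torch.jit.ScriptModule):
    def __init__(self, rank):
        super().__init__()
        # Store a tensor of ones sized by the caller's rank.
        self.a = torch.ones(rank)

    @torch.jit.script_method
    def forward(self) -> torch.Tensor:
        return self.a

    @torch.jit.script_method
    def custom_func(self) -> torch.Tensor:
        return self.a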
Example #30
    def test_torchscript_functions_not_supported(self):
        dst_worker_name = worker_name((self.rank + 1) % self.world_size)

        my_local_script_module = MyScriptModule(self.rank)

        # It is not thread safe to instantiate MyScriptModule in multiple threads.
        # Wait for the local MyScriptModule instantiation to finish;
        # otherwise it could be instantiated in parallel with the
        # server thread in the call below.
        initialize_pg(self.file_init_method, self.rank, self.world_size)
        dist.barrier()

        # rpc_sync still accepts a script class and runs it in
        # the same code path as a Python call.
        ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank, ))

        # rpc_sync does not accept a script module method.
        # Python 3.5 and Python 3.6 throw different error messages; the only
        # common word that can be grepped for is "pickle".
        with self.assertRaisesRegex(TypeError, "pickle"):
            ret = rpc.rpc_async(dst_worker_name,
                                my_local_script_module.forward,
                                args=())