Example #1
    def test_self_add(self):
        self_worker_id = dist.get_worker_id()
        self_worker_name = "worker{}".format(self.rank)

        with self.assertRaisesRegex(
            RuntimeError, "does not support making RPC calls to self"
        ):
            dist.rpc_sync(self_worker_id, torch.add, args=(torch.ones(2, 2), 1))

        with self.assertRaisesRegex(
            RuntimeError, "does not support making RPC calls to self"
        ):
            dist.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
Example #2
 def test_py_built_in(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync("worker{}".format(dst_rank),
                         min,
                         args=(n, n + 1, n + 2))
     self.assertEqual(ret, min(n, n + 1, n + 2))
Example #3
 def test_scalar_add(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync(
         "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n)
     )
     self.assertEqual(ret, (torch.ones(n, n) + n))
Example #4
 def test_py_class_instance_method(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync(
         "worker{}".format(dst_rank), my_class(2).my_instance_method, args=(n,)
     )
     self.assertEqual(ret, my_class(2).my_instance_method(n))
Example #5
 def test_py_class_static_method(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync(
         "worker{}".format(dst_rank), my_class.my_static_method, args=(n + 10,)
     )
     self.assertEqual(ret, my_class.my_static_method(n + 10))
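The my_class helper used in Examples #4, #5, and #24 is not included in these snippets. Because each test only compares the remote result against the same local call (and Example #24 checks that the constructor stores its argument as a), any deterministic, picklable method bodies would pass; the following is a minimal sketch under that assumption, not the original implementation:

class my_class:
    def __init__(self, a):
        self.a = a

    def my_instance_method(self, b):
        # hypothetical body: any deterministic computation works for the test
        return self.a + b

    @staticmethod
    def my_static_method(x):
        # hypothetical body
        return x * 2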
Example #6
 def test_py_function_exception(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     with self.assertRaisesRegex(Exception, "TypeError"):
         ret = dist.rpc_sync("worker{}".format(dst_rank),
                             no_result,
                             args=(10, ))
Example #7
 def test_nonzero(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     x = torch.ones(self.world_size, self.world_size)
     x[self.rank][self.rank] = 0
     ret = dist.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
     self.assertEqual(ret, x.nonzero())
Example #8
 def test_sync_rpc(self):
     dst_rank = (self.rank + 1) % self.world_size
     for i in range(20):
         dist.sync_rpc()
         n = i + self.rank + 1
         ret1 = dist.rpc_sync(
             "worker{}".format(dst_rank),
             torch.add,
             args=(torch.ones(n, n), torch.ones(n, n)),
         )
         dist.sync_rpc()
         ret2 = dist.rpc_sync(
             "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
         )
         dist.sync_rpc()
         self.assertEqual(ret1, torch.ones(n, n) * 2)
         self.assertEqual(ret2, torch.ones(n, n) * 3)
Example #9
    def test_add_with_id(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        worker_id = dist.get_worker_id("worker{}".format(dst_rank))

        ret = dist.rpc_sync(
            worker_id, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
Example #10
 def test_py_user_defined(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync(
         "worker{}".format(dst_rank),
         my_function,
         kwargs={"a": n, "b": n + 1, "c": n + 2},
     )
     self.assertEqual(ret, my_function(n, n + 1, n + 2))
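my_function itself is not shown in these snippets. A definition that is consistent with this test and with the assertion in Example #18 (where results built from my_function(torch.ones(n, n), 2, 0) and my_function(torch.ones(n, n), 1, 0) combine to torch.ones(n, n) + 4) would simply sum its three arguments; treat this as an assumption:

def my_function(a, b, c):
    # assumed definition: summing the arguments matches the checks in
    # Examples #10 and #18
    return a + b + c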
Example #11
 def test_py_tensors(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync(
         "worker{}".format(dst_rank),
         my_tensor_function,
         args=(torch.ones(n, n), torch.ones(n, n)),
     )
     self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))
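my_tensor_function is likewise not shown. The assertion only compares the RPC result against the same local call, so any deterministic tensor function would pass; a plausible minimal sketch (an assumption) is:

def my_tensor_function(a, b):
    # assumed definition; any deterministic tensor op satisfies the test
    return a + b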
Example #12
 def test_nested_rpc(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync(
         "worker{}".format(dst_rank),
         nested_rpc,
         args=("worker{}".format(self.rank),),
     )
     self.assertEqual(ret, torch.ones(2, 2) + 1)
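The nested_rpc helper invoked here is the module-level function shown in Example #25: the callee issues its own rpc_sync back to the caller, which is why the result is torch.ones(2, 2) + 1.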
Example #13
 def test_multi_rpc(self):
     dst_rank = (self.rank + 1) % self.world_size
     for i in range(20):
         n = i + self.rank + 1
         ret = dist.rpc_sync(
             "worker{}".format(dst_rank),
             torch.add,
             args=(torch.ones(n, n), torch.ones(n, n)),
         )
         self.assertEqual(ret, torch.ones(n, n) * 2)
Example #14
 def test_py_tensors_in_container(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     a = [torch.ones(n, n), torch.ones(n, n)]
     b = TensorClass(build_complex_tensors())
     c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
     ret = dist.rpc_sync(
         "worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c)
     )
     self.assertEqual(ret, my_complex_tensor_function(a, b, c))
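TensorClass, build_complex_tensors, and my_complex_tensor_function are not included in these snippets, and their exact behavior cannot be recovered from the assertion, which only compares the remote call against the same local call. Purely as a hypothetical, definitions that would make the test run could look like:

class TensorClass:
    # hypothetical picklable wrapper around a nested tensor structure
    def __init__(self, tensors):
        self.tensors = tensors


def build_complex_tensors():
    # hypothetical: a small nested structure of tensors
    t = torch.ones(3, 3)
    return [t, [t], {"key": t}]


def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
    # hypothetical: echo the nested tensors back; any deterministic function
    # over picklable inputs satisfies the assertion
    return (list_input, tensor_class_input.tensors, dict_input)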
Example #15
 def test_rpc_return_rref(self):
     n = self.rank + 1
     dst_rank1 = n % self.world_size
     dst_rank2 = (n + 1) % self.world_size
     rref = dist.rpc_sync(
         "worker{}".format(dst_rank1),
         rpc_return_rref,
         args=("worker{}".format(dst_rank2), ),
     )
     self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
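rpc_return_rref is not defined in these snippets. Given that the test expects rref.to_here() to equal torch.ones(2, 2) + 1, a plausible sketch (an assumption, following the dist.remote pattern from Example #18) is:

def rpc_return_rref(dst):
    # assumed helper: create an RRef on dst whose value is torch.ones(2, 2) + 1
    return dist.remote(dst, torch.add, args=(torch.ones(2, 2), 1))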
Example #16
    def test_join_rpc(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = dist.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
        dist.join_rpc()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            dist.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )

        # it's safe to call join_rpc() multiple times
        dist.join_rpc()
Example #17
    def test_py_nested_pickle(self):
        n = self.rank + 1
        dst_rank = n % self.world_size

        ret = dist.rpc_sync(
            "worker{}".format(dst_rank),
            run_nested_pickle,
            args=(MyPickleClass(), torch.ones(2, 2)),
        )

        m = MyPickleClass()
        m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
        self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))
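Neither MyPickleClass nor run_nested_pickle appears in these snippets. The local reference computation (m.set(my_tensor_function(...)) followed by run_nested_pickle(m, ...)) suggests a picklable holder whose tensor field is set by the helper; the following is only a sketch under that assumption:

class MyPickleClass:
    # hypothetical picklable holder with a settable tensor field
    def __init__(self):
        self.t = None

    def set(self, val):
        self.t = val


def run_nested_pickle(pickle_cls_instance, tensor):
    # hypothetical: derive a tensor, store it on the instance, and return it
    pickle_cls_instance.set(my_tensor_function(tensor, tensor))
    return pickle_cls_instance.t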
Example #18
    def test_py_rpc_rref_args(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = dist.remote("worker{}".format(dst_rank),
                             my_function,
                             args=(torch.ones(n, n), 2, 0))
        rref_b = dist.remote("worker{}".format(dst_rank),
                             my_function,
                             args=(torch.ones(n, n), 1, 0))

        c = dist.rpc_sync("worker{}".format(dst_rank),
                          my_rref_function,
                          args=(rref_a, rref_b))

        self.assertEqual(c, torch.ones(n, n) + 4)
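my_rref_function is also not shown here. Assuming my_function sums its arguments (see the sketch after Example #10), adding the two RRef values gives (torch.ones(n, n) + 2) + (torch.ones(n, n) + 1), which matches the asserted torch.ones(n, n) + 4; the assumed definition is:

def my_rref_function(rref_a, rref_b):
    # assumed helper: fetch both RRef values and add them
    return rref_a.to_here() + rref_b.to_here()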
Example #19
    def test_rpc_complex_args(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            num_tensors = 10
            tensors = []
            for i in range(num_tensors):
                tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
            ret = dist.rpc_sync('worker{}'.format(dst_rank),
                                torch.stack,
                                args=(tensors, ))
            self.assertEqual(torch.stack(tensors), ret)

            # Verify the appropriate tensors have been attached to the autograd graph.
            send_functions = dist_autograd._current_context()._send_functions()
            next_funcs = send_functions[0].next_functions
            idx = 0
            for i in range(num_tensors):
                if i % 2 == 0:
                    self.assertEqual('torch::autograd::AccumulateGrad',
                                     next_funcs[i][0].name())
                    self.assertEqual(tensors[i], next_funcs[i][0].variable)
                else:
                    self.assertIsNone(next_funcs[i][0])
Example #20
    def test_autograd_send_function(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = dist.rpc_sync('worker{}'.format(dst_rank),
                                torch.add,
                                args=(t1, t2))

            # Get send function.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Retrieve the next functions in the graph.
            next_funcs = send_functions[0].next_functions
            self.assertEqual(2, len(next_funcs))

            # We should now hit t1 and t2 in the autograd graph.
            self.assertEqual('torch::autograd::AccumulateGrad',
                             next_funcs[0][0].name())
            self.assertEqual(t1, next_funcs[0][0].variable)
            self.assertEqual(0, next_funcs[0][1])
            self.assertEqual('torch::autograd::AccumulateGrad',
                             next_funcs[1][0].name())
            self.assertEqual(t2, next_funcs[1][0].variable)
            self.assertEqual(0, next_funcs[1][1])

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()
Example #21
    def test_autograd_functions(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = dist.rpc_sync("worker{}".format(dst_rank),
                                torch.add,
                                args=(t1, t2))
            dist.rpc_sync("worker{}".format(dst_rank),
                          _set_rpc_done,
                          args=(context_id, ))

            # Get send function.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Retrieve the next functions in the graph.
            next_funcs = list(send_functions.values())[0].next_functions
            self.assertEqual(2, len(next_funcs))

            # We should now hit t1 and t2 in the autograd graph.
            self.assertEqual("torch::autograd::AccumulateGrad",
                             next_funcs[0][0].name())
            self.assertEqual(t1, next_funcs[0][0].variable)
            self.assertEqual(0, next_funcs[0][1])
            self.assertEqual("torch::autograd::AccumulateGrad",
                             next_funcs[1][0].name())
            self.assertEqual(t2, next_funcs[1][0].variable)
            self.assertEqual(0, next_funcs[1][1])

            # Test recv functions.
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])

            # We should have send/recv functions from the previous rank; get all
            # contexts on this node to find them.

            # Wait for the prev rank to be done with rpc.
            while not prev_rank_rpc_done:
                time.sleep(0.1)

            # Now verify the autograd graph.
            ctx = dist_autograd._retrieve_context(prev_rank_context_id)

            # Get the send function.
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Verify next function is AddBackward0
            next_funcs = list(send_functions.values())[0].next_functions
            self.assertEqual(1, len(next_funcs))
            add_backward_fn = next_funcs[0][0]
            self.assertEqual("AddBackward0", add_backward_fn.name())

            # Verify the next two functions are the same recv backward function.
            next_funcs = add_backward_fn.next_functions
            self.assertEqual(2, len(next_funcs))
            self.assertEqual("torch::distributed::autograd::RecvRpcBackward",
                             next_funcs[0][0].name())
            self.assertEqual("torch::distributed::autograd::RecvRpcBackward",
                             next_funcs[1][0].name())
            self.assertEqual(next_funcs[0][0], next_funcs[1][0])

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()
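Example #21 relies on module-level state (prev_rank_rpc_done, prev_rank_context_id) and a _set_rpc_done helper that are not shown in these snippets. A minimal sketch that would make the wait loop and the _retrieve_context lookup work, assuming these live in the same module as the tests:

prev_rank_rpc_done = False
prev_rank_context_id = 0


def _set_rpc_done(context_id):
    # runs on the callee: record that the previous rank finished its RPC and
    # remember that rank's autograd context id
    global prev_rank_rpc_done, prev_rank_context_id
    prev_rank_rpc_done = True
    prev_rank_context_id = context_id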
Example #22
 def test_py_no_return_result(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync("worker{}".format(dst_rank), no_result)
     self.assertEqual(ret, no_result())
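no_result is not defined in the snippets. Examples #6 and #22 only require a function that takes no arguments and returns None, so that calling it with an argument fails remotely; a minimal sketch:

def no_result():
    # takes no arguments and returns None; no_result(10) in Example #6
    # therefore raises a TypeError on the remote worker
    pass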
Example #23
 def test_expected_src(self):
     dst_rank = (self.rank + 1) % self.world_size
     expected_src_rank = (self.rank - 1) % self.world_size
     ret = dist.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
     value = VALUE_FUTURE.result()
     self.assertEqual(value, expected_src_rank)
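set_value and VALUE_FUTURE are not shown. The test reads a value on the callee that was written by the RPC arriving from the previous rank, so a plausible sketch (an assumption) is a module-level future that the remote call resolves:

import concurrent.futures

VALUE_FUTURE = concurrent.futures.Future()


def set_value(value):
    # runs on the callee: publish the caller's rank for the local test to read
    VALUE_FUTURE.set_result(value)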
Example #24
 def test_py_class_constructor(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     ret = dist.rpc_sync("worker{}".format(dst_rank), my_class, args=(n,))
     self.assertEqual(ret.a, n)
Example #25
def nested_rpc(dst):
    return dist.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))