Example #1
    def test_rpc_complex_args(self):
        with dist_autograd.context() as context_id:
            num_tensors = 10
            tensors = []
            for i in range(num_tensors):
                tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
            ret = rpc.rpc_sync(
                "worker{}".format(self._next_rank()), torch.stack, args=(tensors,)
            )
            self.assertEqual(torch.stack(tensors), ret)

            # Verify appropriate tensors have been attached to the autograd graph.
            next_funcs = list(
                dist_autograd._current_context()._send_functions().values()
            )[0].next_functions
            for i in range(num_tensors):
                if i % 2 == 0:
                    self.assertEqual(
                        "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                    )
                    self.assertEqual(tensors[i], next_funcs[i][0].variable)
                else:
                    self.assertIsNone(next_funcs[i][0])

            # Verify that the worker id has been recorded in the context
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), 1)
            dst_rank = (self.rank + 1) % self.world_size
            self.assertEqual(worker_ids[0], dst_rank)
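
The assertions above walk the autograd graph through next_functions: each entry is a (function, input_index) pair, where a leaf tensor that requires grad shows up as a torch::autograd::AccumulateGrad node (with the tensor available as .variable) and an input that does not require grad shows up as None. The same structure can be inspected on a purely local graph; the snippet below is a minimal local sketch (no RPC, no dist_autograd) that uses only the public grad_fn / next_functions attributes and mirrors the alternating requires_grad pattern in the test.

import torch

# Only t1 requires grad, mirroring the alternating requires_grad pattern above.
t1 = torch.ones(3, 3, requires_grad=True)
t2 = torch.zeros(3, 3, requires_grad=False)
out = torch.stack([t1, t2])

# The stack node has one incoming edge per input tensor: an AccumulateGrad
# leaf node for t1, and None for t2, which is not part of the autograd graph.
edges = out.grad_fn.next_functions
assert edges[0][0].name() == "torch::autograd::AccumulateGrad"
assert torch.equal(edges[0][0].variable, t1)
assert edges[1][0] is None
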
Example #2
    def test_worker_ids_recorded(self):
        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
        with dist_autograd.context() as context_id:
            # if no tensors require grad, we do not add the send functions, so
            # no worker ids should be recorded.
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync(
                    "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
            # no worker ids should be recorded.
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), 0)

            # worker_ids should be recorded when tensors do require grad
            t1.requires_grad = True
            t2.requires_grad = True
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync(
                    "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
            # all worker_ids in dst_ranks should be recorded.
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), len(dst_ranks))
            self.assertEqual(set(worker_ids), dst_ranks)
Example #3
    def test_graph_for_py_nested_call_itself(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               my_py_nested_call,
                               args=(t1, t2, (self.rank - 1 + self.world_size) % self.world_size, self.world_size, 0))
            rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                         _set_rpc_done, args=(context_id, 1))

            # For self.rank, there are 2 graphs to verify.
            # One is for the current context id, when this rank sends the
            # first rpc call and executes the torch.add() operator.
            # The other is for the prev context id, when this rank makes the
            # nested call.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(2, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(2, len(recv_functions))
            self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
                                                  list(recv_functions.values())[1],
                                                  t1, t2, ret)
            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

            # Verify two pairs of send and recv functions for nested
            # call
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()
Example #4
    def test_graph_for_py_nested_call(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            nest_dst_rank = (dst_rank + 1) % self.world_size
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               my_py_nested_call,
                               args=(t1, t2, dst_rank, self.world_size, 1))
            for rd in [1, 2, 3]:
                rpc.rpc_sync("worker{}".format(
                    (self.rank + rd) % self.world_size),
                             _set_rpc_done,
                             args=(context_id, rd))

            # For self.rank, there are 4 graphs to verify.
            # The first is for the current context id, when this rank sends the
            # first rpc call.
            # The second is for the prev context id, when this rank makes the
            # 1st nested call.
            # The third is for the prev prev context id, when this rank makes
            # the 2nd nested call.
            # The last is for the prev prev prev context id, when this rank
            # executes the torch.add() operator.

            # Verify first graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[0], t1, t2, ret)

            # Verify second graph for 1st nested call.
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify third graph for 2nd nested call.
            self._check_rpc_done(2)
            ctx = dist_autograd._retrieve_context(ctx_ids[2])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify last graph for rpc call execution.
            self._check_rpc_done(3)
            ctx = dist_autograd._retrieve_context(ctx_ids[3])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[0])
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()
Example #5
    def _test_graph(self, fn):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank), fn, args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, 1))

            # Verify graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[0], t1, t2, ret)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # Verify graph for previous context id.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[0])
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()
Example #6
    def test_autograd_send_function(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync('worker{}'.format(dst_rank),
                               torch.add,
                               args=(t1, t2))

            # Get send function.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Retrieve the next functions in the graph.
            next_funcs = list(send_functions.values())[0].next_functions
            self.assertEqual(2, len(next_funcs))

            # We should now hit t1 and t2 in the autograd graph.
            self.assertEqual('torch::autograd::AccumulateGrad',
                             next_funcs[0][0].name())
            self.assertEqual(t1, next_funcs[0][0].variable)
            self.assertEqual(0, next_funcs[0][1])
            self.assertEqual('torch::autograd::AccumulateGrad',
                             next_funcs[1][0].name())
            self.assertEqual(t2, next_funcs[1][0].variable)
            self.assertEqual(0, next_funcs[1][1])

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()
Example #7
    def step(self):
        """
        Performs a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on each worker
        containing parameters to be optimized, and will block until all workers
        return. The current distributed autograd
        :class:`~torch.distributed.autograd.context` will be used globally.
        """
        autograd_ctx_id = dist_autograd._current_context()._context_id()
        rpc_futs = []
        for optim in self.remote_optimizers:
            rpc_futs.append(rpc.rpc_async(
                optim.owner(),
                _local_optimizer_step,
                args=(optim, autograd_ctx_id),
            ))
        _wait_for_all(rpc_futs)
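
The step() above fans out one rpc_async per remote optimizer and blocks on all of the returned futures, using whichever distributed autograd context is current on the calling worker. The sketch below shows where such a call typically sits in a training loop; it is schematic rather than runnable: rref1, rref2, and forward_pass are hypothetical placeholders for parameter RRefs and a forward computation, and the exact backward()/step() signatures differ across PyTorch versions (newer releases pass an explicit context_id to both).

import torch
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer

# rref1 / rref2 are hypothetical RRefs to parameters owned by other workers,
# e.g. obtained from rpc.remote() calls during model setup.
dist_optim = DistributedOptimizer(torch.optim.SGD, [rref1, rref2], lr=0.05)

with dist_autograd.context():
    # Forward pass whose loss depends on the remote parameters
    # (forward_pass is a hypothetical helper built on rpc.rpc_sync / rpc.remote).
    loss = forward_pass(rref1, rref2)
    # Record gradients for this distributed autograd context on every
    # participating worker.
    dist_autograd.backward([loss])
    # Run the local optimizer.step() on each worker that owns parameters and
    # block until they all return.
    dist_optim.step()
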
Example #8
    def test_no_graph_with_tensors_not_require_grad(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done, args=(context_id, 1))

            ctx = dist_autograd._current_context()
            send_functions = ctx._send_functions()
            self.assertEqual(len(send_functions), 0)
            recv_functions = ctx._recv_functions()
            self.assertEqual(len(recv_functions), 0)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # prev context id is not passed over as tensors do not require grads
            with self.assertRaises(RuntimeError):
                ctx = dist_autograd._retrieve_context(ctx_ids[1])
Example #9
    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync("worker{}".format(dst_rank),
                                   torch.add,
                                   args=(t1, t2))
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote("worker{}".format(dst_rank),
                                 torch.add,
                                 args=(t1, t2)).to_here().wait()
            else:
                raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, 1))

            ctx = dist_autograd._current_context()
            send_functions = ctx._send_functions()
            self.assertEqual(len(send_functions), 0)
            recv_functions = ctx._recv_functions()
            self.assertEqual(len(recv_functions), 0)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            if ExecMode.RPC_SYNC == exec_mode:
                # prev context id is not passed over as tensors do not require
                # grads
                with self.assertRaises(RuntimeError):
                    ctx = dist_autograd._retrieve_context(ctx_ids[1])
            elif ExecMode.REMOTE == exec_mode:
                # NB: RRef.to_here() always passes the autograd context to
                # the callee, as the caller does not know whether the return
                # value would contain a requires_grad tensor or not.
                pass
Example #10
    def test_rpc_complex_args(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            num_tensors = 10
            tensors = []
            for i in range(num_tensors):
                tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
            ret = rpc.rpc_sync('worker{}'.format(dst_rank),
                               torch.stack,
                               args=(tensors,))
            self.assertEqual(torch.stack(tensors), ret)

            # Verify appropriate tensors have been attached to the autograd graph.
            next_funcs = list(
                dist_autograd._current_context()._send_functions().values()
            )[0].next_functions
            for i in range(num_tensors):
                if i % 2 == 0:
                    self.assertEqual('torch::autograd::AccumulateGrad',
                                     next_funcs[i][0].name())
                    self.assertEqual(tensors[i], next_funcs[i][0].variable)
                else:
                    self.assertIsNone(next_funcs[i][0])
Example #11
    def test_autograd_functions(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               torch.add,
                               args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, ))

            # Get send function.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Retrieve the next functions in the graph.
            next_funcs = list(send_functions.values())[0].next_functions
            self.assertEqual(2, len(next_funcs))

            # We should now hit t1 and t2 in the autograd graph.
            self.assertEqual("torch::autograd::AccumulateGrad",
                             next_funcs[0][0].name())
            self.assertEqual(t1, next_funcs[0][0].variable)
            self.assertEqual(0, next_funcs[0][1])
            self.assertEqual("torch::autograd::AccumulateGrad",
                             next_funcs[1][0].name())
            self.assertEqual(t2, next_funcs[1][0].variable)
            self.assertEqual(0, next_funcs[1][1])

            # Test recv functions.
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])

            # We should have send/recv functions from the previous rank, get all
            # contexts in this node to find them.

            # Wait for the prev rank to be done with rpc.
            while not prev_rank_rpc_done:
                time.sleep(0.1)

            # Now verify the autograd graph.
            ctx = dist_autograd._retrieve_context(prev_rank_context_id)

            # Get the send function.
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Verify next function is AddBackward0
            next_funcs = list(send_functions.values())[0].next_functions
            self.assertEqual(1, len(next_funcs))
            add_backward_fn = next_funcs[0][0]
            self.assertEqual("AddBackward0", add_backward_fn.name())

            # Verify the next two functions are the same recv backward function.
            next_funcs = add_backward_fn.next_functions
            self.assertEqual(2, len(next_funcs))
            self.assertEqual("torch::distributed::autograd::RecvRpcBackward",
                             next_funcs[0][0].name())
            self.assertEqual("torch::distributed::autograd::RecvRpcBackward",
                             next_funcs[1][0].name())
            self.assertEqual(next_funcs[0][0], next_funcs[1][0])

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()
Example #12
    def _test_graph_for_py_nested_call_itself(self, exec_mode):
        dst_rank = (self.rank + 1) % self.world_size

        # This is needed for the `dist.barrier` below.
        # For `RpcAgent`s other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )

        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync("worker{}".format(dst_rank),
                                   my_py_nested_call,
                                   args=(t1, t2,
                                         (self.rank - 1 + self.world_size) %
                                         self.world_size, self.world_size, 0))
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    "worker{}".format(dst_rank),
                    my_py_nested_call,
                    args=(t1, t2,
                          (self.rank - 1 + self.world_size) % self.world_size,
                          self.world_size, 0)).to_here().wait()
            else:
                raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

            rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                         _set_rpc_done,
                         args=(context_id, 1))

            # For self.rank, there are 2 graphs to verify.
            # One is for the current context id, when this rank sends the
            # first rpc call and executes the torch.add() operator.
            # The other is for the prev context id, when this rank makes the
            # nested call.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(2, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(2, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[1], t1, t2, ret)
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[1])

            # Verify two pairs of send and recv functions for nested
            # call
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()
Example #13
    def _test_graph(self, fn, exec_mode):
        dst_rank = (self.rank + 1) % self.world_size

        # This is needed for the `dist.barrier` below.
        # For `RpcAgent`s other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )

        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync("worker{}".format(dst_rank),
                                   fn,
                                   args=(t1, t2))
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote("worker{}".format(dst_rank),
                                 fn,
                                 args=(t1, t2)).to_here().wait()
            else:
                raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, 1))

            # Verify graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[0], t1, t2, ret)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # Verify graph for previous context id.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[0])
            # this barrier is needed so one worker does not clean up their
            # autograd context before another worker tries to access it.
            dist.barrier()

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()