Example #1
    def test_graph_for_py_nested_call(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            nest_dst_rank = (dst_rank + 1) % self.world_size
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               my_py_nested_call,
                               args=(t1, t2, dst_rank, self.world_size, 1))
            for rd in [1, 2, 3]:
                rpc.rpc_sync("worker{}".format(
                    (self.rank + rd) % self.world_size),
                             _set_rpc_done,
                             args=(context_id, rd))

            # For self.rank, there are 4 graphs to verify:
            # The first is for the current context id, when this rank sends
            # its first rpc call.
            # The second is for the prev context id, when this rank makes
            # the 1st nested call.
            # The third is for the prev prev context id, when this rank
            # makes the 2nd nested call.
            # The last is for the prev prev prev context id, when this rank
            # executes the torch.add() operator.

            # Verify first graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[0], t1, t2, ret)

            # Verify second graph for 1st nested call.
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify third graph for 2nd nested call.
            self._check_rpc_done(2)
            ctx = dist_autograd._retrieve_context(ctx_ids[2])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify the last graph for the rpc call execution.
            self._check_rpc_done(3)
            ctx = dist_autograd._retrieve_context(ctx_ids[3])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[0])
            # This barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()
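
These tests rely on a few module-level helpers that are not shown in the excerpts: the callee-side _set_rpc_done, the ctx_ids bookkeeping it fills in, and the _check_rpc_done wait used on the caller. A minimal sketch of what they could look like is given below; the exact shapes (list sizes, sleep interval) are assumptions, not the definitive implementation.

import time

# Per-process bookkeeping: index i holds what the worker at rank distance i
# reported via _set_rpc_done (whether its rpc is done, and its context id).
rpc_done = [False, False, False, False]
ctx_ids = [-1, -1, -1, -1]

def _set_rpc_done(ctx_id, rank_distance):
    # Runs on the callee via rpc_sync: records that the caller at the given
    # rank distance finished its rpc and shares that caller's context id.
    global rpc_done
    global ctx_ids
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id

# Inside the test class (sketch):
    def _check_rpc_done(self, rank_distance):
        # Block until the worker at the given rank distance has called
        # _set_rpc_done on this process.
        while not rpc_done[rank_distance]:
            time.sleep(0.1)
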
Example #2
def _all_contexts_cleaned_up(num_contexts, timeout_seconds=10):
    global known_context_ids
    start = time.time()
    context_id_to_raised = {}
    while time.time() - start < timeout_seconds:
        for context_id in known_context_ids:
            try:
                dist_autograd._retrieve_context(context_id)
            except RuntimeError:
                # Retrieval failing means this context has been cleaned up.
                context_id_to_raised[context_id] = True
        if len(context_id_to_raised) == num_contexts:
            break
        # Avoid a hot spin while waiting for remote cleanup to propagate.
        time.sleep(0.1)
    # all contexts have been cleaned up if trying to retrieve any context resulted in a RuntimeError.
    success = len(context_id_to_raised) == num_contexts and all(context_id_to_raised.values())
    return success
Example #3
    def test_context_cleanup_many_workers(self):
        global known_context_ids
        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync("worker{}".format(dst_rank), _store_context_id, args=(context_id,))
        # the thread's context id should be cleaned up
        with self.assertRaises(RuntimeError):
            dist_autograd._retrieve_context(context_id)
        # check that all contexts have been cleaned up.
        success = _all_contexts_cleaned_up(num_contexts=len(dst_ranks))
        self.assertTrue(success)
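
Example #3 also depends on _store_context_id and the known_context_ids global that _all_contexts_cleaned_up (Example #2) polls. A plausible minimal definition, offered as a sketch rather than the actual source, is:

# Context ids reported to this process by remote callers; polled by
# _all_contexts_cleaned_up to confirm every one of them has been cleaned up.
known_context_ids = []

def _store_context_id(context_id):
    # Runs on the callee via rpc_sync: remember the caller's context id so
    # this process can later check that it was cleaned up.
    global known_context_ids
    known_context_ids.append(context_id)
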
Example #4
    def test_graph_for_py_nested_call_itself(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               my_py_nested_call,
                               args=(t1, t2, (self.rank - 1 + self.world_size) % self.world_size, self.world_size, 0))
            rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                         _set_rpc_done, args=(context_id, 1))

            # For self.rank, there are 2 graphs to verify.
            # One is for the current context id, when this rank sends the
            # first rpc call and executes the torch.add() operator.
            # The other is for the prev context id, when this rank makes
            # the nested call.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(2, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(2, len(recv_functions))
            self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
                                                  list(recv_functions.values())[1],
                                                  t1, t2, ret)
            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

            # Verify two pairs of send and recv functions for nested
            # call
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)
            # This barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()
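
The nested-call tests invoke my_py_nested_call, which is not defined in these excerpts. A sketch consistent with how its arguments (t1, t2, dst, world_size, hops) are used above might look like this; treat it as an illustration, not the exact helper:

def my_py_nested_call(t1, t2, dst, world_size, hops):
    # Forward the call one worker further around the ring; once no hops are
    # left, run the actual torch.add so the end of the chain executes the op.
    next_dst = (dst + 1) % world_size
    if hops > 0:
        return rpc.rpc_sync("worker{}".format(next_dst),
                            my_py_nested_call,
                            args=(t1, t2, next_dst, world_size, hops - 1))
    else:
        return rpc.rpc_sync("worker{}".format(next_dst),
                            torch.add,
                            args=(t1, t2))
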
Example #5
    def test_autograd_context(self):
        context_ids = []
        for i in range(1000):
            with dist_autograd.context() as context_id:
                self.assertEqual(
                    context_id,
                    dist_autograd._retrieve_context(context_id)._context_id())
                # The top 16 bits of the context id should be the worker_id.
                self.assertEqual(self.worker_id, context_id >> 48)
                context_ids.append(context_id)

        for context_id in context_ids:
            with self.assertRaisesRegex(
                    RuntimeError,
                    'Could not find autograd context with id: {}'.format(
                        context_id)):
                dist_autograd._retrieve_context(context_id)
Example #6
    def _test_graph(self, fn):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank), fn, args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, 1))

            # Verify graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[0], t1, t2, ret)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # Verify graph for previous context id.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[0])
            # This barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()
Example #7
    def test_autograd_context(self):
        # Verify max possible id.
        max_auto_increment = 281474976710655  # 2 ** 48 - 1
        self.assertEqual(max_auto_increment + (self.worker_id << 48),
                         dist_autograd._get_max_id())

        context_ids = []
        for i in range(1000):
            with dist_autograd.context() as context_id:
                self.assertEqual(
                    context_id,
                    dist_autograd._retrieve_context(context_id)._context_id())
                # The top 16 bits of the context id should be the worker_id.
                self.assertEqual(self.worker_id, context_id >> 48)
                context_ids.append(context_id)

        for context_id in context_ids:
            with self.assertRaisesRegex(
                    RuntimeError,
                    'Could not find autograd context with id: {}'.format(
                        context_id)):
                dist_autograd._retrieve_context(context_id)
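
Examples #5 and #7 assert the bit layout of a distributed autograd context id: the top 16 bits carry the worker_id and the low 48 bits carry an auto-incrementing counter, so the largest counter value is 2 ** 48 - 1 = 281474976710655. A small standalone illustration of that arithmetic (the concrete numbers are made up, not taken from a real run):

MAX_AUTO_INCREMENT = 2 ** 48 - 1            # == 281474976710655
worker_id = 3
local_counter = 12345
context_id = (worker_id << 48) | local_counter

assert context_id >> 48 == worker_id                      # top 16 bits
assert context_id & MAX_AUTO_INCREMENT == local_counter   # low 48 bits
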
Example #8
    def test_no_graph_with_tensors_not_require_grad(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done, args=(context_id, 1))

            ctx = dist_autograd._current_context()
            send_functions = ctx._send_functions()
            self.assertEqual(len(send_functions), 0)
            recv_functions = ctx._recv_functions()
            self.assertEqual(len(recv_functions), 0)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # prev context id is not passed over as tensors do not require grads
            with self.assertRaises(RuntimeError):
                ctx = dist_autograd._retrieve_context(ctx_ids[1])
Example #9
    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync("worker{}".format(dst_rank),
                                   torch.add,
                                   args=(t1, t2))
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote("worker{}".format(dst_rank),
                                 torch.add,
                                 args=(t1, t2)).to_here().wait()
            else:
                raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, 1))

            ctx = dist_autograd._current_context()
            send_functions = ctx._send_functions()
            self.assertEqual(len(send_functions), 0)
            recv_functions = ctx._recv_functions()
            self.assertEqual(len(recv_functions), 0)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            if ExecMode.RPC_SYNC == exec_mode:
                # prev context id is not passed over as tensors do not require
                # grads
                with self.assertRaises(RuntimeError):
                    ctx = dist_autograd._retrieve_context(ctx_ids[1])
            elif ExecMode.REMOTE == exec_mode:
                # NB: RRef.to_here() always passes the autograd context to
                # the callee, as the caller does not know whether the return
                # value would contain a requires_grad tensor or not.
                pass
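
The parameterized tests above dispatch on an ExecMode enum that is not shown in these excerpts. A minimal sketch covering the two modes exercised here (the member values are assumptions):

from enum import Enum

class ExecMode(Enum):
    RPC_SYNC = 1  # invoke the target with rpc.rpc_sync and use its return value
    REMOTE = 2    # invoke the target with rpc.remote and fetch the result via RRef.to_here()
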
Example #10
    def test_autograd_send_function(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               torch.add,
                               args=(t1, t2))

            # Get send function.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Retrieve the next functions in the graph.
            next_funcs = list(send_functions.values())[0].next_functions
            self.assertEqual(2, len(next_funcs))

            # We should now hit t1 and t2 in the autograd graph.
            self.assertEqual('torch::autograd::AccumulateGrad',
                             next_funcs[0][0].name())
            self.assertEqual(t1, next_funcs[0][0].variable)
            self.assertEqual(0, next_funcs[0][1])
            self.assertEqual('torch::autograd::AccumulateGrad',
                             next_funcs[1][0].name())
            self.assertEqual(t2, next_funcs[1][0].variable)
            self.assertEqual(0, next_funcs[1][1])

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()
Example #11
    def test_autograd_functions(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               torch.add,
                               args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, ))

            # Get send function.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Retrieve the next functions in the graph.
            next_funcs = list(send_functions.values())[0].next_functions
            self.assertEqual(2, len(next_funcs))

            # We should now hit t1 and t2 in the autograd graph.
            self.assertEqual("torch::autograd::AccumulateGrad",
                             next_funcs[0][0].name())
            self.assertEqual(t1, next_funcs[0][0].variable)
            self.assertEqual(0, next_funcs[0][1])
            self.assertEqual("torch::autograd::AccumulateGrad",
                             next_funcs[1][0].name())
            self.assertEqual(t2, next_funcs[1][0].variable)
            self.assertEqual(0, next_funcs[1][1])

            # Test recv functions.
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self.assertEqual(ret.grad_fn, list(recv_functions.values())[0])

            # We should also have send/recv functions from the previous rank;
            # retrieve that rank's context on this node to find them.

            # Wait for the prev rank to be done with rpc.
            while not prev_rank_rpc_done:
                time.sleep(0.1)

            # Now verify the autograd graph.
            ctx = dist_autograd._retrieve_context(prev_rank_context_id)

            # Get the send function.
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Verify next function is AddBackward0
            next_funcs = list(send_functions.values())[0].next_functions
            self.assertEqual(1, len(next_funcs))
            add_backward_fn = next_funcs[0][0]
            self.assertEqual("AddBackward0", add_backward_fn.name())

            # Verify the next two functions are the same recv backward function.
            next_funcs = add_backward_fn.next_functions
            self.assertEqual(2, len(next_funcs))
            self.assertEqual("torch::distributed::autograd::RecvRpcBackward",
                             next_funcs[0][0].name())
            self.assertEqual("torch::distributed::autograd::RecvRpcBackward",
                             next_funcs[1][0].name())
            self.assertEqual(next_funcs[0][0], next_funcs[1][0])

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()
Example #12
    def _test_graph_for_py_nested_call_itself(self, exec_mode):
        dst_rank = (self.rank + 1) % self.world_size

        # This is needed for the `dist.barrier` below.
        # For `RpcAgent` implementations other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )

        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync("worker{}".format(dst_rank),
                                   my_py_nested_call,
                                   args=(t1, t2,
                                         (self.rank - 1 + self.world_size) %
                                         self.world_size, self.world_size, 0))
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote(
                    "worker{}".format(dst_rank),
                    my_py_nested_call,
                    args=(t1, t2,
                          (self.rank - 1 + self.world_size) % self.world_size,
                          self.world_size, 0)).to_here().wait()
            else:
                raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

            rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                         _set_rpc_done,
                         args=(context_id, 1))

            # For self.rank, there are 2 graphs to verify.
            # One is for the current context id, when this rank sends the
            # first rpc call and executes the torch.add() operator.
            # The other is for the prev context id, when this rank makes
            # the nested call.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(2, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(2, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[1], t1, t2, ret)
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[1])

            # Verify two pairs of send and recv functions for nested
            # call
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)
            # This barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()
Example #13
    def _test_graph(self, fn, exec_mode):
        dst_rank = (self.rank + 1) % self.world_size

        # This is needed for the `dist.barrier` below.
        # For `RpcAgent` implementations other than `ProcessGroupAgent`,
        # no `_default_pg` is initialized.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="gloo",
                init_method=self.init_method,
                rank=self.rank,
                world_size=self.world_size,
            )

        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            if ExecMode.RPC_SYNC == exec_mode:
                ret = rpc.rpc_sync("worker{}".format(dst_rank),
                                   fn,
                                   args=(t1, t2))
            elif ExecMode.REMOTE == exec_mode:
                ret = rpc.remote("worker{}".format(dst_rank),
                                 fn,
                                 args=(t1, t2)).to_here().wait()
            else:
                raise ValueError("Unrecognized ExecMode {}".format(exec_mode))

            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done,
                         args=(context_id, 1))

            # Verify graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(
                list(send_functions.values())[0],
                list(recv_functions.values())[0], t1, t2, ret)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # Verify graph for previous context id.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(
                list(send_functions.values())[0])
            # This barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()