Exemplo n.º 1
0
 def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
     """Join the torch.distributed RPC group as ``worker<global_rank>``.

     The master port is taken from the RPC_MASTER_PORT environment variable,
     falling back to 15000 when unset. After initialization the process-wide
     RPC timeout is set from ``self.rpc_timeout_sec``.
     """
     master_port = os.getenv('RPC_MASTER_PORT', '15000')
     os.environ['MASTER_PORT'] = master_port
     worker_label = f"worker{global_rank}"
     rpc.init_rpc(worker_label, rank=global_rank, world_size=world_size)
     # Apply the configured per-RPC deadline to the freshly created agent.
     rpc._set_rpc_timeout(self.rpc_timeout_sec)
     self._is_rpc_initialized = True
Exemplo n.º 2
0
    def test_timeout_in_torchscript_function(self):
        """rpc_async + fut.wait() inside a TorchScript function raises on timeout."""
        # Only rank 0 drives the test; the other ranks act as RPC targets.
        if self.rank != 0:
            return

        target = worker_name((self.rank + 1) % self.world_size)
        positional = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        keyword = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        timeout_regex = self.get_timeout_error_regex()

        # An explicit 0.5s per-call timeout shorter than the RPC's runtime
        # must override the default and raise.
        with self.assertRaisesRegex(RuntimeError, timeout_regex):
            rpc_async_call_with_timeout(target, positional, keyword, 0.5)

        # With no per-call timeout, a tiny process-wide default must also raise.
        rpc._set_rpc_timeout(0.001)
        with self.assertRaisesRegex(RuntimeError, timeout_regex):
            script_rpc_async_call(target, positional, keyword)

        # A timeout of zero disables the deadline, so the call completes even
        # though the default is still 0.001s.
        result = rpc_async_call_with_timeout(target, positional, keyword, 0)
        self.assertEqual(result, torch.tensor([8, 8]))
        # Restore the default so shutdown messages are not cut short.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
Exemplo n.º 3
0
    def test_rpc_timeouts(self):
        """Exercise the agent-wide default timeout set via rpc._set_rpc_timeout."""
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = "worker{}".format(dst_rank)

        # With a 1 ms default, each of these futures should fail with a
        # timeout error carrying the expected message.
        rpc._set_rpc_timeout(timedelta(milliseconds=1))
        pending = [
            rpc.rpc_async(dst_worker, my_sleep_func, args=())
            for _ in range(10)
        ]
        for pending_fut in pending:
            with self.assertRaisesRegex(RuntimeError, "RPC ran for more than"):
                pending_fut.wait()

        # Changing the default must only affect RPCs issued afterwards:
        # launch a long-lived RPC under a generous timeout, then shrink the
        # default and launch another.
        rpc._set_rpc_timeout(timedelta(seconds=200))
        long_running = rpc.rpc_async(dst_worker, my_sleep_func, args=(1, ))
        rpc._set_rpc_timeout(timedelta(milliseconds=1))
        short_deadline = rpc.rpc_async(dst_worker, my_sleep_func, args=(1, ))
        # short_deadline should time out; long_running should not.
        with self.assertRaises(RuntimeError):
            short_deadline.wait()
        long_running.wait()

        # A zero timeout means "no deadline": the RPC runs to completion.
        rpc._set_rpc_timeout(timedelta(seconds=0))
        rpc.rpc_async(dst_worker, my_sleep_func, args=()).wait()

        # Restore the default so shutdown messages can process cleanly.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT)
Exemplo n.º 4
0
    def test_timeout_in_python(self):
        """rpc_async launched in TorchScript but awaited in Python still times out."""
        # Only rank 0 drives the test; the other ranks act as RPC targets.
        if self.rank != 0:
            return

        target = worker_name((self.rank + 1) % self.world_size)
        positional = (torch.tensor([1, 1]), torch.tensor([2, 2]))
        keyword = {
            "first_kwarg": torch.tensor([2, 2]),
            "second_kwarg": torch.tensor([3, 3]),
        }
        timeout_regex = self.get_timeout_error_regex()

        # An explicit 0.5s per-call timeout shorter than the RPC's runtime
        # must surface as a RuntimeError when the future is awaited in Python.
        pending = rpc_async_call_with_timeout_future_ret(target, positional, keyword, 0.5)
        with self.assertRaisesRegex(RuntimeError, timeout_regex):
            pending.wait()

        # With no per-call timeout, a tiny process-wide default must also raise.
        rpc._set_rpc_timeout(0.001)
        pending = rpc_async_call_future_ret(target, positional, keyword)
        with self.assertRaisesRegex(RuntimeError, timeout_regex):
            pending.wait()

        # A timeout of zero disables the deadline, so the call completes.
        pending = rpc_async_call_with_timeout_future_ret(target, positional, keyword, 0)
        self.assertEqual(pending.wait(), torch.tensor([8, 8]))
        # Restore the default so shutdown messages are not cut short.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
Exemplo n.º 5
0
def run_parameter_server(rank, world_size):
    """Entry point for the parameter-server process.

    The parameter server only hosts the model and answers requests from the
    trainers. ``rpc.shutdown()`` waits for all workers to complete by default,
    so this process blocks until every trainer has finished, then exits.
    """
    print('PS master initializing RPC')
    rpc.init_rpc(name='parameter_server', rank=rank, world_size=world_size)
    # Allow each RPC up to a minute before it is considered failed.
    rpc._set_rpc_timeout(timedelta(seconds=60))
    print('RPC initialized! Running parameter server...')
    rpc.shutdown(graceful=True)
    print('RPC shutdown on parameter server.')
Exemplo n.º 6
0
def run_worker(rank, world_size, epochs, batch_size, learning_rate, input_dim, hidden_dim, layer_dim, output_dim,
               train_set, validation_set, test_set):
    """Entry point for a trainer process: join the RPC group, train, shut down."""
    print(f'Worker rank {rank} initializing RPC')
    rpc.init_rpc(name=f'trainer_{rank}', rank=rank, world_size=world_size)
    # Allow each RPC up to a minute before it is considered failed.
    rpc._set_rpc_timeout(timedelta(seconds=60))
    print(f'Worker {rank} done initializing RPC')

    # Build this worker's local model and run the training loop against the
    # parameter server.
    model = WorkerNetwork(input_dim, hidden_dim, layer_dim, output_dim)
    trainer = ParameterWorkerTrainer(rank, world_size, model, train_set,
                                     batch_size, learning_rate, validation_set,
                                     test_set)
    trainer.train(epochs)
    # Blocks until all RPC peers are done, then tears down the agent.
    rpc.shutdown()
Exemplo n.º 7
0
    def test_rpc_builtin_timeout(self):
        """Timeout behavior for builtin-op (torch.add) RPCs under the faulty agent."""
        dst_worker = worker_name((self.rank + 1) % self.world_size)
        timeout_regex = self.get_timeout_error_regex()
        add_args = (torch.tensor(1), torch.tensor(1))

        # PYTHON_CALL message types which correspond to Python UDF over RPC
        # by default get a delay (see faulty_rpc_agent_test_fixture), so a 1s
        # per-call timeout is exceeded on both the sync and async paths.
        with self.assertRaisesRegex(RuntimeError, timeout_regex):
            rpc.rpc_sync(dst_worker, torch.add, args=add_args, timeout=1)

        delayed = rpc.rpc_async(dst_worker, torch.add, args=add_args, timeout=1)
        with self.assertRaisesRegex(RuntimeError, timeout_regex):
            delayed.wait()

        # The currently configured default timeout is large enough that the
        # same delayed RPC completes when no per-call timeout is given.
        rpc.rpc_async(dst_worker, torch.add, args=add_args).wait()

        # Shrinking the default (with no per-call override) restores the timeout.
        rpc._set_rpc_timeout(0.001)
        delayed = rpc.rpc_async(dst_worker, torch.add, args=add_args)
        with self.assertRaisesRegex(RuntimeError, timeout_regex):
            delayed.wait()

        # An explicit per-call timeout of 0 disables the deadline entirely.
        rpc.rpc_async(dst_worker, torch.add, args=add_args, timeout=0).wait()
        # Restore the default so shutdown messages process cleanly.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
Exemplo n.º 8
0
    def test_rpc_script_timeout(self):
        """Timeout behavior for TorchScript-function RPCs.

        Mirrors test_rpc_builtin_timeout, but the remote callable is a
        scripted function (my_script_func) instead of a builtin op.
        """
        next_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(next_rank)
        expected_error = self.get_timeout_error_regex()

        # A 1s per-call timeout is shorter than the (delayed) RPC takes, for
        # both the sync and async call paths.
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)

        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure that the currently set default timeout is large enough such
        # that RPCs with delays still complete.
        fut = rpc.rpc_async(
            dst_worker, my_script_func, args=(torch.tensor(1),)
        )
        fut.wait()

        # Ensure timeout if we set a new default and don't override it per call.
        rpc._set_rpc_timeout(0.001)
        fut = rpc.rpc_async(
            dst_worker, my_script_func, args=(torch.tensor(1),)
        )
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure run to completion if we specify a timeout of 0, which
        # disables the deadline regardless of the still-tiny 0.001s default.
        # (The original re-set the 0.001 default here; that call was a no-op
        # duplicate of the one above and is removed for consistency with
        # test_rpc_builtin_timeout.)
        fut = rpc.rpc_async(
            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0
        )
        fut.wait()
        # Reset for clean shutdown.
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)