Example #1
    def test_nested_remote(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size

        rref = rpc.remote(
            "worker{}".format(dst_rank1),
            nested_remote,
            args=("worker{}".format(dst_rank2), ),
        )
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
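The `nested_remote` helper called in Example #1 is not reproduced here. A minimal sketch that is consistent with the assertion (`torch.ones(2, 2) + 3`); the body is an assumption, not the source implementation:

import torch
import torch.distributed.rpc as rpc

def nested_remote(dst):
    # Assumed helper: issue a second-level rpc.remote() from the callee and
    # return the materialized value, so the caller's rref.to_here() yields
    # torch.ones(2, 2) + 3.
    rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
    return rref.to_here()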
Example #2
    def _test_remote_message_delay_timeout(self, func, args, dst=None):
        if self.rank != 0:
            return
        # Test the case where remote message is eventually processed on the owner,
        # but the future on the creator times out before the response comes back.
        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
        dst_worker = "worker{}".format(dst_rank)
        # 10 ms timeout
        rref = rpc.remote(dst_worker, func, args=args, timeout=0.001)
        # Future corresponding to the remote creation should time out.
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref._get_future().wait()

        # Call to ensure pending callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # to_here() should now pick up that rpc.remote() creation has failed.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref.to_here()

        # Test the case where rpc.remote() times out, but to_here() has already
        # started blocking before.
        # NOTE: we only test this when not sending to self, as to_here() calls
        # localValue(), which does not send an RPC and thus does not have
        # a timeout. This can be supported by allowing future.wait() to
        # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280)
        if dst_rank != self.rank:
            slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2)

            with self.assertRaisesRegex(RuntimeError, expected_error):
                # to_here() should raise timeout error, since it does not know about the
                # status of rpc.remote().
                slow_rref.to_here(0.001)
        # Note: If we proceed with shutdown, UserRRef will send out a RRefUserDelete
        # but this can be a noop since it may not exist on the owner yet. Later,
        # the owner can process the RRef creation and wait for the delete message,
        # thus leading to a timeout.
        # Therefore, we wait until we get notification that pending owners have
        # been confirmed before sending out RRefUserDeletes.
        if dst_rank != self.rank:
            wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank)
Example #3
 def test_py_udf_remote(self):
     n = self.rank + 1
     dst_rank = n % self.world_size
     rref = rpc.remote(
         "worker{}".format(dst_rank),
         my_function,
         kwargs={
             "a": n,
             "b": n + 1,
             "c": n + 2
         },
     )
     self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
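Example #3 (and Example #12 further down) calls a `my_function` helper that is not shown. A sketch consistent with both assertions, assuming a plain three-argument sum:

def my_function(a, b, c):
    # Assumed helper: simple sum of the three arguments, so the remote result
    # equals the locally computed my_function(n, n + 1, n + 2).
    return a + b + c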
Example #4
def run_driver(rank, world_size, gpu_list, dataset, batch_size, lr,
               max_epoch, client_epoch, model, seed):
    exp_id = str(int(time.time()))
    print(f"Driver initializing RPC, rank {rank}, world size {world_size}")
    rpc.init_rpc(name="driver", rank=rank, world_size=world_size)
    print("Initialized driver")
    param_server_rref = rpc.remote("parameter_server", get_parameter_server,
                                   args=(gpu_list[0], world_size - 1, dataset, batch_size,
                                         lr, model, max_epoch, client_epoch, seed, exp_id))
    for _rank in range(1, world_size - 1):
        print(f"Driver registering worker node {_rank}")
        worker_server_rref = rpc.remote(f"trainer_{_rank}", get_worker,
                                        args=(gpu_list[_rank], _rank, world_size - 1, dataset,
                                              model, batch_size, lr, client_epoch, seed, exp_id))
        print(f"Driver binding worker {_rank} with param server")
        remote_method(ParameterServer.embedding_workers, param_server_rref, worker_server_rref)
        remote_method(TrainerNet.embedding_param_server, worker_server_rref, param_server_rref)

    fut = remote_method_async(ParameterServer.instruct_training, param_server_rref)
    fut.wait()
    rpc.shutdown()
    print("RPC shutdown on Driver")
Example #5
def init_workers(num_workers, h_dim, start, end, delta, isTe):
    kwargs = get_work_units(num_workers, start, end, delta, isTe)

    rrefs = []
    for i in range(len(kwargs)):
        rrefs.append(
            rpc.remote(
                'worker' + str(i),
                get_remote_gae,
                args=(h_dim, ld.load_lanl_dist, kwargs[i]),
            ))

    return rrefs
Example #6
 def test_nested_rref(self):
     n = self.rank + 1
     dst_rank1 = n % self.world_size
     dst_rank2 = (n + 1) % self.world_size
     rref_of_rrefs = rpc.remote(
         "worker{}".format(dst_rank1),
         nested_rref,
         args=("worker{}".format(dst_rank2),),
     )
     rrefs = rref_of_rrefs.to_here()
     self.assertEqual(len(rrefs), 2)
     self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
     self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
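As with Example #1, the `nested_rref` helper in Example #6 is assumed. A version consistent with the two assertions would return a pair of RRefs created on the second destination:

import torch
import torch.distributed.rpc as rpc

def nested_rref(dst):
    # Assumed helper: create two RRefs on dst and return them as a tuple,
    # matching the torch.ones(2, 2) + 1 and + 2 checks above.
    return (
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)),
    )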
Example #7
    def test_rref_to_here_timeout_in_jit(self):
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = "worker{}".format(dst_rank)
        rref = rpc.remote(dst_worker,
                          torch.add,
                          args=(torch.tensor(1), torch.tensor(1)))
        expected_error = get_timeout_error_regex(
            dist_utils.TEST_CONFIG.rpc_backend_name)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref_to_here_with_timeout(rref, 0.01)
Example #8
 def create_clients(self, client_id_triple):
     for id, rank, world_size in client_id_triple:
         client = rpc.remote(id,
                             Client,
                             kwargs=dict(id=id,
                                         log_rref=self.log_rref,
                                         rank=rank,
                                         world_size=world_size,
                                         config=self.config))
         writer = SummaryWriter(
             f'{self.tb_path}/{self.config.experiment_prefix}_client_{id}')
         self.clients.append(
             ClientRef(id, client, tensorboard_writer=writer))
         self.client_data[id] = []
Example #9
 def test_rref_timeout_pickle_script_func(self):
     # Similar to above test, but calls python rpc with script function.
     if self.rank != 0:
         return
     dst_rank = (self.rank + 1) % self.world_size
     dst_worker = "worker{}".format(dst_rank)
     rref = rpc.remote(
         dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
     )
     # Will ensure error handling callbacks are run.
     wait_until_pending_futures_and_users_flushed()
     # Call RPC with script function that takes RRef, ensure timeout during pickling
     with self.assertRaisesRegex(RuntimeError, "RRef creation"):
         rpc.rpc_sync(dst_worker, rref_to_here, args=(rref, ))
Example #10
 def test_rref_timeout_pickle_in_jit(self):
     if self.rank != 0:
         return
     dst_rank = (self.rank + 1) % self.world_size
     dst_worker = "worker{}".format(dst_rank)
     rref = rpc.remote(
         dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
     )
     # Will ensure error handling callbacks are run.
     wait_until_pending_futures_and_users_flushed()
     # Call RPC with RRef arg in JIT, which will go through JIT pickling and
     # ensure error is raised.
     with self.assertRaisesRegex(RuntimeError, "RRef creation"):
         rpc_async_with_rref_arg(dst_worker, (rref, ))
Example #11
    def test_rref_as_arg_and_return(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        local_ret = one_arg(torch.ones(2, 2))

        # create rref on current rank
        rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),))

        # pass rref to another user in rpc call
        ret = rpc.rpc_sync(worker_name(dst_rank), rref_to_here, args=(rref,))
        self.assertEqual(ret, local_ret)

        # return rref in rpc call
        rref1 = rpc.rpc_sync(worker_name(dst_rank), return_rref, args=(rref,))
        self.assertEqual(rref1.to_here(), local_ret)

        # pass rref to another user in remote call
        rref2 = rpc.remote(worker_name(dst_rank), rref_to_here, args=(rref,))
        self.assertEqual(rref2.to_here(), local_ret)

        # return rref in remote call
        rref3 = rpc.remote(worker_name(dst_rank), return_rref, args=(rref,))
        self.assertEqual(rref3.to_here().to_here(), local_ret)
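The `one_arg`, `rref_to_here`, and `return_rref` helpers used in Example #11 are assumed to be thin wrappers. A plausible sketch consistent with the assertions (all three bodies are assumptions):

def one_arg(value):
    # Assumed helper: deterministic single-argument op, so the remote result
    # matches the locally computed local_ret.
    return value + 1

def rref_to_here(rref):
    # Assumed helper: fetch the value owned by the RRef.
    return rref.to_here()

def return_rref(rref):
    # Assumed helper: hand the RRef back unchanged, so the caller receives an
    # RRef to an RRef (hence the to_here().to_here() chain above).
    return rref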
Example #12
 def _test_self_remote_rref_as_rpc_arg(self, dst):
     self_worker_info = rpc.get_worker_info()
     rref = rpc.remote(self_worker_info,
                       my_function,
                       args=(torch.ones(2, 2), 1, 3))
     fut = rpc.rpc_async(dst,
                         add_rref_to_value,
                         args=(rref, torch.ones(2, 2)))
     ret = rpc.rpc_sync(dst,
                        add_rref_to_value,
                        args=(rref, torch.ones(2, 2) + 1))
     self.assertEqual(ret, torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + 1)
     self.assertEqual(fut.wait(),
                      torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))
Example #13
    def test_pass_local_rrefs(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        dst_worker = "worker{}".format(dst_rank)

        rref = RRef(40)
        self.assertEqual(
            rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90)
        self.assertEqual(
            rpc.rpc_async(dst_worker, add_rref_to_value,
                          args=(rref, 50)).wait(), 90)
        self.assertEqual(
            rpc.remote(dst_worker, add_rref_to_value,
                       args=(rref, 50)).to_here(), 90)
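Examples #12 and #13 both call `add_rref_to_value`, which is not shown. A minimal assumed implementation that matches the expected results (40 + 50 = 90, and the summed tensors in Example #12):

def add_rref_to_value(rref, value):
    # Assumed helper: materialize the RRef on the callee and add the extra value.
    return rref.to_here() + value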
Example #14
    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        # Group the parameter RRefs by the worker that owns them.
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)

        # Create one _LocalOptimizer on each owning worker, holding only that
        # worker's parameters.
        self.remote_optimizers = []
        for worker, param_rrefs in per_worker_params_rref.items():
            remote_optim_rref = rpc.remote(
                worker,
                _LocalOptimizer,
                args=[optimizer_class, param_rrefs] + list(args),
                kwargs=kwargs,
            )
            self.remote_optimizers.append(remote_optim_rref)
Example #15
def init_empty_workers(num_workers, worker_constructor, worker_args):
    empty = {'jobs': 0, 'start': None, 'end': None}
    
    rrefs = [
        rpc.remote(
            'worker'+str(i),
            worker_constructor,
            args=(LOAD_FN, empty, *worker_args),
            kwargs={'head': i==0}
        )
        for i in range(num_workers)
    ]

    return rrefs
Example #16
    def test_rref_str(self):
        rref1 = RRef(self.rank)
        id_class = "GloballyUniqueId"
        self.assertEqual(
            "OwnerRRef({}({}, 0))".format(id_class, self.rank),
            rref1.__str__()
        )

        dst_rank = (self.rank + 1) % self.world_size
        rref2 = rpc.remote("worker{}".format(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
        self.assertEqual(
            rref2.__str__(),
            "UserRRef(RRefId = {0}({1}, 1), ForkId = {0}({1}, 2))".format(id_class, self.rank)
        )
Example #17
 def get_rrefs(self, base_name, base_id, num_nodes, worker=True):
     """ get rrefs of remote machines (workers or servers)
     Args
     base_name     template name of deployed workers
     base_id       the lowest rank of a node in the deployment
     num_nodes     the number of nodes of which the server should fetch references
     """
     rrefs = [
         remote(base_name + str(node_id),
                get_worker if worker else get_server)
         for node_id in range(base_id, base_id + num_nodes)
     ]
     types = [type(rref.to_here()) for rref in rrefs]
     return types, rrefs
Example #18
    def test_rref_to_here_timeout(self):
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = "worker{}".format(dst_rank)
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref.to_here(0.01)

        rref.to_here()
Example #19
def init_workers(num_workers, start, end, delta, isTe, worker_constructor, worker_args):
    kwargs = get_work_units(num_workers, start, end, delta, isTe)

    rrefs = []
    for i in range(len(kwargs)):
        rrefs.append(
            rpc.remote(
                'worker'+str(i),
                worker_constructor,
                args=(LOAD_FN, kwargs[i], *worker_args),
                kwargs={'head': i==0}
            )
        )

    return rrefs
Example #20
 def __init__(self, world_size):
     self.ob_rrefs = []
     self.agent_rref = RRef(self)
     self.rewards = {}
     self.saved_log_probs = {}
     self.policy = Policy()
     self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
     self.eps = np.finfo(np.float32).eps.item()
     self.running_reward = 0
     self.reward_threshold = DummyEnv().reward_threshold
     for ob_rank in range(1, world_size):
         ob_info = rpc.get_worker_info(worker_name(ob_rank))
         self.ob_rrefs.append(remote(ob_info, Observer))
         self.rewards[ob_info.id] = []
         self.saved_log_probs[ob_info.id] = []
Example #21
 def test_remote_timeout_to_here_in_jit(self):
     # Test that calling to_here() in JIT will raise timeout error if
     # rpc.remote failed.
     if self.rank != 0:
         return
     dst_rank = (self.rank + 1) % self.world_size
     dst_worker = "worker{}".format(dst_rank)
     rref = rpc.remote(
         dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
     )
     # Will ensure error handling callbacks are run.
     wait_until_pending_futures_and_users_flushed()
     # Call to_here() within a ScriptFunction and ensure it raises
     with self.assertRaisesRegex(RuntimeError, "RRef creation"):
         rref_to_here(rref)
Example #22
 def __init__(self, world_size):
     self.ob_rrefs = []
     self.agent_rref = RRef(self)
     self.rewards = {}
     self.saved_log_probs = {}
     self.policy = Policy()
     self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
     self.eps = numpy.finfo(numpy.float32).eps.item()
     self.running_reward = 0
     self.reward_threshold = gym.make(ENV).spec.reward_threshold
     for ob_rank in range(1, world_size):
         ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
         self.ob_rrefs.append(remote(ob_info, Observer))
         self.rewards[ob_info.id] = []
         self.saved_log_probs[ob_info.id] = []
Example #23
    def _test_remote_message_dropped_timeout(self, func, args, dst=None):
        if self.rank != 0:
            return

        # test the case where rpc.remote() message creation is completely dropped.
        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
        dst_worker = "worker{}".format(dst_rank)
        # Since we fail python_remote_call messages synchronously, the future
        # corresponding to this remote call will be marked with an error when
        # this function returns.
        rref = rpc.remote(dst_worker, func, args=args)
        # Call to ensure pending callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref.to_here()
Example #24
    def test_async_function_remote_multi(self):
        dst1 = worker_name((self.rank + 1) % self.world_size)
        dst2 = worker_name((self.rank + 2) % self.world_size)

        num = 20
        rrefs = []
        for i in range(num):
            rrefs.append(
                rpc.remote(
                    dst1,
                    async_add,
                    args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i),
                )
            )

        for i in range(num):
            self.assertEqual(rrefs[i].to_here(), torch.ones(2, 2) + i)
Example #25
 def _exec_func(self, exec_mode, method, *args):
     if ExecMode.LOCAL == exec_mode:
         if len(args) == 1 and isinstance(args[0], list):
             return method(*args[0])
         return method(*args)
     elif ExecMode.RPC_SYNC == exec_mode:
         return rpc.rpc_sync('worker{}'.format(self._next_rank()),
                             method,
                             args=(args))
     elif ExecMode.REMOTE == exec_mode:
         rref = rpc.remote('worker{}'.format(self._next_rank()),
                           method,
                           args=(args))
         return rref.to_here()
     else:
         raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
Example #26
def remote_method(method, obj_rref, *args, **kwargs):
    """
    Call rpc.remote on a method in a remote object.

    Args:
        method: the method (for example, Class.method)
        obj_rref (RRef): remote reference to the object
        args: positional arguments to pass to the method
        kwargs: keyword arguments to pass to the method

    Returns a RRef to the remote method call result.
    """
    return rpc.remote(obj_rref.owner(),
                      _call_method,
                      args=[method, obj_rref] + list(args),
                      kwargs=kwargs)
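Example #26 delegates to a `_call_method` helper that is not included in the snippet. The usual pattern (an assumption, not confirmed by the source) is to resolve the RRef locally on its owner and invoke the unbound method on that instance:

def _call_method(method, obj_rref, *args, **kwargs):
    # Assumed helper: runs on the object's owner, where local_value() is
    # valid; the unbound method is applied to the local instance.
    return method(obj_rref.local_value(), *args, **kwargs)

With this pattern, a call such as `remote_method(SomeClass.step, obj_rref, x)` would return an RRef to the result of `obj.step(x)` executed on the owner.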
Example #27
 def __init__(self, config, world_size):
     self.e = 0
     self.config = config.config_NeuralPlayer
     self.preprocessor = None
     self._init_dataset(self.config.config_Datasets)
     self._init_agent(self.config.config_Agent)
     self.agent_rref = RRef(self.agent)
     self.world_size = world_size  # number of remote agents
     self.worker_rrefs = []
     self.data_gatherer = ScoreDataGatherer()
     for worker_rank in range(1, self.world_size):
         worker_info = rpc.get_worker_info(f"worker{worker_rank}")
         self.worker_rrefs.append(
             remote(worker_info,
                    CentralAgentWorker,
                    args=(config, worker_rank),
                    timeout=600))
Example #28
    def run_training_loop(self, N_iter, coord_rref):
        self.iter = 0
        for i in range(N_iter):
            #create world_size-1 EpisodeRunner objects
            rref_list = [
                rpc.remote(f"rank{j}", EpisodeRunner, (j, ))
                for j in range(1, world_size)
            ]

            #launch episodes
            fut_list = [
                r.rpc_async().run_episode(coord_rref) for r in rref_list
            ]
            [fut.wait() for fut in fut_list]

            #update model
            self.update_model()
Example #29
    def __init__(self, graph: PipelineModulesGraph, chunks: int = 1, checkpoint: str = "except_last",) -> None:
        super().__init__()

        check_pytorch_version()

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

        self.chunks = chunks
        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[checkpoint]

        self.partitions = [
            self.Partition(
                nodes,
                rpc.remote(
                    handler.owner(),
                    PartitionHandler,
                    args=(
                        handler,
                        nodes[0].module.device,
                        len(nodes[0].inputs),
                        nodes[-1].num_outputs,
                        i,
                        self.chunks,
                        checkpoint_stop,
                    ),
                ),
            )
            for i, (nodes, handler) in enumerate(graph.partition_graph())
        ]
        self.input_consumers = [
            next(
                self.DataConsumer(partition, input_consumer.consumer_input_idx, input_consumer.output_idx)
                for partition in self.partitions
                if partition.nodes[0] is input_consumer.consumer
            )
            for input_consumer in graph.model_input_consumers
        ]

        self.graph = graph
Example #30
    def test_rref_forward_chain(self):
        ttl = 8
        n = self.rank + 1
        dst_rank = n % self.world_size

        rref = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )

        ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)

        for i in range(ttl):
            self.assertEqual(len(ret_rref), 1)
            ret_rref = ret_rref[0].to_here()

        ret = ret_rref
        self.assertEqual(ret, torch.add(torch.ones(n, n), 1))
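The recursive `rref_forward_chain` helper in Example #30 is assumed. A sketch that matches the unwrapping loop above (ttl levels of one-element lists, ending in the original RRef's value):

import torch.distributed.rpc as rpc

def rref_forward_chain(dst, world_size, rref, ttl):
    # Assumed helper: forward the RRef ttl more hops, wrapping each hop's
    # result RRef in a one-element list; once ttl reaches 0, return the value.
    if ttl > 0:
        current_dst = "worker{}".format(dst)
        next_dst = (dst + 1) % world_size
        ret_rref = rpc.remote(
            current_dst,
            rref_forward_chain,
            args=(next_dst, world_size, rref, ttl - 1),
        )
        return [ret_rref]
    else:
        return rref.to_here()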