def test_nested_remote(self):
    """A remote call that itself issues a nested rpc.remote; check the chained result."""
    base = self.rank + 1
    first_dst = base % self.world_size
    second_dst = (base + 1) % self.world_size
    rref = rpc.remote(
        "worker{}".format(first_dst),
        nested_remote,
        args=("worker{}".format(second_dst),),
    )
    self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
def _test_remote_message_delay_timeout(self, func, args, dst=None):
    """Remote creation times out on the creator even though the owner eventually processes it.

    Args:
        func: function to invoke via rpc.remote.
        args: positional arguments for func.
        dst: optional explicit destination rank; defaults to the next rank.
            Callers may pass self.rank to exercise the self-remote path.

    Only rank 0 executes the test body.
    """
    if self.rank != 0:
        return
    # Test the case where remote message is eventually processed on the owner,
    # but the future on the creator times out before the response comes back.
    dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
    dst_worker = "worker{}".format(dst_rank)
    # 1 ms timeout (0.001 s) — short enough that the remote creation cannot
    # complete before the creator-side future expires.
    rref = rpc.remote(dst_worker, func, args=args, timeout=0.001)
    # Future corresponding to the remote creation should time out.
    expected_error = self.get_timeout_error_regex()
    with self.assertRaisesRegex(RuntimeError, expected_error):
        rref._get_future().wait()
    # Call to ensure pending callbacks are run.
    wait_until_pending_futures_and_users_flushed()
    # to_here() should now pick up that rpc.remote() creation has failed.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rref.to_here()

    # Test the case where rpc.remote() times out, but to_here() has already
    # started blocking before.
    # NOTE: we only test this when not sending to self, as to_here() calls
    # localValue(), which does not send an RPC and thus does not have
    # a timeout. This can be supported by allowing future.wait() to
    # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280)
    if dst_rank != self.rank:
        slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2)

        with self.assertRaisesRegex(RuntimeError, expected_error):
            # to_here() should raise timeout error, since it does not know about the
            # status of rpc.remote().
            slow_rref.to_here(0.001)
    # Note: If we proceed with shutdown, UserRRef will send out a RRefUserDelete
    # but this can be a noop since it may not exist on the owner yet. Later,
    # the owner can process the RRef creation and wait for the delete message,
    # thus leading to a timeout.
    # Therefore, we wait until we get notification that pending owners have
    # been confirmed before sending out RRefUserDeletes.
    if dst_rank != self.rank:
        wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank)
def test_py_udf_remote(self):
    """Invoke a Python UDF via rpc.remote using keyword arguments only."""
    offset = self.rank + 1
    target = offset % self.world_size
    remote_ref = rpc.remote(
        "worker{}".format(target),
        my_function,
        kwargs={"a": offset, "b": offset + 1, "c": offset + 2},
    )
    # The remote result must match a purely local invocation.
    self.assertEqual(remote_ref.to_here(), my_function(offset, offset + 1, offset + 2))
def run_driver(rank, world_size, gpu_list, dataset, batch_size, lr, max_epoch, client_epoch, model, seed):
    """Driver process: initializes RPC, wires trainers to the parameter server, runs training.

    Args:
        rank: this process's RPC rank.
        world_size: total number of RPC participants.
        gpu_list: per-rank GPU assignment; gpu_list[0] goes to the parameter server.
        dataset, batch_size, lr, max_epoch, client_epoch, model, seed: training
            configuration forwarded to the parameter server and each trainer.
    """
    # Experiment id derived from wall-clock time; shared by server and workers.
    exp_id = str(int(time.time()))
    print(f"Driver initializing RPC, rank {rank}, world size {world_size}")
    rpc.init_rpc(name="driver", rank=rank, world_size=world_size)
    print("Initialized driver")
    param_server_rref = rpc.remote("parameter_server", get_parameter_server, args=(gpu_list[0], world_size - 1, dataset, batch_size, lr, model, max_epoch, client_epoch, seed, exp_id))
    # Ranks 1..world_size-2 host trainers; presumably the remaining rank is the
    # parameter server — TODO confirm against the launcher.
    for _rank in range(1, world_size - 1):
        print(f"Driver registering worker node {_rank}")
        worker_server_rref = rpc.remote(f"trainer_{_rank}", get_worker, args=(gpu_list[_rank], _rank, world_size - 1, dataset, model, batch_size, lr, client_epoch, seed, exp_id))
        print(f"Driver binding worker {_rank} with param server")
        # Exchange RRefs so server and trainer can call each other.
        remote_method(ParameterServer.embedding_workers, param_server_rref, worker_server_rref)
        remote_method(TrainerNet.embedding_param_server, worker_server_rref, param_server_rref)
    # Kick off training on the server and block until it finishes.
    fut = remote_method_async(ParameterServer.instruct_training, param_server_rref)
    fut.wait()
    rpc.shutdown()
    print("RPC shutdown on Driver")
def init_workers(num_workers, h_dim, start, end, delta, isTe):
    """Spawn one remote GAE worker per work unit and return their RRefs."""
    work_units = get_work_units(num_workers, start, end, delta, isTe)
    return [
        rpc.remote(
            'worker' + str(idx),
            get_remote_gae,
            args=(h_dim, ld.load_lanl_dist, work_units[idx]),
        )
        for idx in range(len(work_units))
    ]
def test_nested_rref(self):
    """A remote call returning a pair of RRefs; fetch and verify both values."""
    base = self.rank + 1
    owner_a = base % self.world_size
    owner_b = (base + 1) % self.world_size
    rref_of_rrefs = rpc.remote(
        "worker{}".format(owner_a),
        nested_rref,
        args=("worker{}".format(owner_b),),
    )
    inner = rref_of_rrefs.to_here()
    self.assertEqual(len(inner), 2)
    self.assertEqual(inner[0].to_here(), torch.ones(2, 2) + 1)
    self.assertEqual(inner[1].to_here(), torch.ones(2, 2) + 2)
def test_rref_to_here_timeout_in_jit(self):
    """to_here() with a tiny timeout inside TorchScript must raise a timeout error."""
    if self.rank != 0:
        return
    peer = (self.rank + 1) % self.world_size
    peer_name = "worker{}".format(peer)
    rref = rpc.remote(peer_name, torch.add, args=(torch.tensor(1), torch.tensor(1)))
    timeout_regex = get_timeout_error_regex(dist_utils.TEST_CONFIG.rpc_backend_name)
    with self.assertRaisesRegex(RuntimeError, timeout_regex):
        rref_to_here_with_timeout(rref, 0.01)
def create_clients(self, client_id_triple):
    """Create a remote Client and a TensorBoard writer for each (id, rank, world_size) triple."""
    for client_id, client_rank, client_world_size in client_id_triple:
        remote_client = rpc.remote(
            client_id,
            Client,
            kwargs=dict(
                id=client_id,
                log_rref=self.log_rref,
                rank=client_rank,
                world_size=client_world_size,
                config=self.config,
            ),
        )
        board_writer = SummaryWriter(
            f'{self.tb_path}/{self.config.experiment_prefix}_client_{client_id}')
        self.clients.append(
            ClientRef(client_id, remote_client, tensorboard_writer=board_writer))
        self.client_data[client_id] = []
def test_rref_timeout_pickle_script_func(self):
    """Pickling an errored RRef into a script-function RPC surfaces the creation failure."""
    # Similar to the test above, but goes through a python RPC carrying a script function.
    if self.rank != 0:
        return
    peer = (self.rank + 1) % self.world_size
    peer_name = "worker{}".format(peer)
    rref = rpc.remote(peer_name, torch.add, args=(torch.tensor(1), torch.tensor(1)))
    # Run pending error-handling callbacks before touching the RRef.
    wait_until_pending_futures_and_users_flushed()
    # RPC with a script function taking an RRef: pickling must report the failure.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rpc.rpc_sync(peer_name, rref_to_here, args=(rref,))
def test_rref_timeout_pickle_in_jit(self):
    """JIT pickling of an errored RRef argument must raise the creation failure."""
    if self.rank != 0:
        return
    peer = (self.rank + 1) % self.world_size
    peer_name = "worker{}".format(peer)
    rref = rpc.remote(peer_name, torch.add, args=(torch.tensor(1), torch.tensor(1)))
    # Run pending error-handling callbacks before touching the RRef.
    wait_until_pending_futures_and_users_flushed()
    # The RRef arg goes through JIT pickling inside rpc_async; that must raise.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rpc_async_with_rref_arg(peer_name, (rref,))
def test_rref_as_arg_and_return(self):
    """Pass an RRef as an argument and receive RRefs back via rpc_sync and rpc.remote."""
    nxt = self.rank + 1
    peer = nxt % self.world_size
    expected = one_arg(torch.ones(2, 2))
    # RRef owned by the current rank.
    local_rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),))
    # Forward the RRef to another user over rpc_sync.
    sync_result = rpc.rpc_sync(worker_name(peer), rref_to_here, args=(local_rref,))
    self.assertEqual(sync_result, expected)
    # Receive an RRef back from an rpc_sync call.
    returned = rpc.rpc_sync(worker_name(peer), return_rref, args=(local_rref,))
    self.assertEqual(returned.to_here(), expected)
    # Forward the RRef inside a remote call.
    via_remote = rpc.remote(worker_name(peer), rref_to_here, args=(local_rref,))
    self.assertEqual(via_remote.to_here(), expected)
    # Receive an RRef back from a remote call (an RRef of an RRef).
    nested = rpc.remote(worker_name(peer), return_rref, args=(local_rref,))
    self.assertEqual(nested.to_here().to_here(), expected)
def _test_self_remote_rref_as_rpc_arg(self, dst):
    """Create an RRef owned by self and ship it as an argument in async and sync RPCs."""
    me = rpc.get_worker_info()
    self_rref = rpc.remote(me, my_function, args=(torch.ones(2, 2), 1, 3))
    pending = rpc.rpc_async(dst, add_rref_to_value, args=(self_rref, torch.ones(2, 2)))
    sync_result = rpc.rpc_sync(dst, add_rref_to_value, args=(self_rref, torch.ones(2, 2) + 1))
    self.assertEqual(sync_result, torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + 1)
    self.assertEqual(pending.wait(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))
def test_pass_local_rrefs(self):
    """A locally created RRef(40) is usable via rpc_sync, rpc_async, and rpc.remote."""
    nxt = self.rank + 1
    target = "worker{}".format(nxt % self.world_size)
    local_rref = RRef(40)
    self.assertEqual(
        rpc.rpc_sync(target, add_rref_to_value, args=(local_rref, 50)), 90)
    self.assertEqual(
        rpc.rpc_async(target, add_rref_to_value, args=(local_rref, 50)).wait(), 90)
    self.assertEqual(
        rpc.remote(target, add_rref_to_value, args=(local_rref, 50)).to_here(), 90)
def __init__(self, optimizer_class, params_rref, *args, **kwargs):
    """Group parameter RRefs by owning worker and create one remote _LocalOptimizer per worker.

    Args:
        optimizer_class: local optimizer class instantiated on each owner.
        params_rref: iterable of RRefs to parameters, possibly spread over workers.
        *args, **kwargs: forwarded to each _LocalOptimizer constructor.
    """
    by_owner = defaultdict(list)
    for param_rref in params_rref:
        by_owner[param_rref.owner()].append(param_rref)
    self.remote_optimizers = [
        rpc.remote(
            owner,
            _LocalOptimizer,
            args=[optimizer_class, owned_params] + list(args),
            kwargs=kwargs,
        )
        for owner, owned_params in by_owner.items()
    ]
def init_empty_workers(num_workers, worker_constructor, worker_args):
    """Spin up num_workers remote workers that all start with an empty work unit."""
    blank_unit = {'jobs': 0, 'start': None, 'end': None}
    handles = []
    for idx in range(num_workers):
        handles.append(
            rpc.remote(
                'worker' + str(idx),
                worker_constructor,
                args=(LOAD_FN, blank_unit, *worker_args),
                # Worker 0 is designated the head node.
                kwargs={'head': idx == 0},
            )
        )
    return handles
def test_rref_str(self):
    """__str__ formats differ between owner-side and user-side RRefs."""
    owner_rref = RRef(self.rank)
    id_class = "GloballyUniqueId"
    self.assertEqual(
        "OwnerRRef({}({}, 0))".format(id_class, self.rank),
        owner_rref.__str__(),
    )
    peer = (self.rank + 1) % self.world_size
    user_rref = rpc.remote("worker{}".format(peer), torch.add,
                           args=(torch.ones(2, 2), 1))
    self.assertEqual(
        user_rref.__str__(),
        "UserRRef(RRefId = {0}({1}, 1), ForkId = {0}({1}, 2))".format(id_class, self.rank)
    )
def get_rrefs(self, base_name, base_id, num_nodes, worker=True):
    """Get rrefs of remote machines (workers or servers).

    Args:
        base_name: template name of deployed workers
        base_id: the lowest rank of a node in the deployment
        num_nodes: the number of nodes of which the server should fetch references
        worker: fetch worker references when True, server references otherwise

    Returns:
        A (types, rrefs) pair, where types holds the concrete type of each
        fetched remote object.
    """
    factory = get_worker if worker else get_server
    handles = []
    for node_id in range(base_id, base_id + num_nodes):
        handles.append(remote(base_name + str(node_id), factory))
    fetched_types = [type(handle.to_here()) for handle in handles]
    return fetched_types, handles
def test_rref_to_here_timeout(self):
    """to_here(0.01) raises a timeout; a later unbounded to_here() still succeeds."""
    if self.rank != 0:
        return
    peer = (self.rank + 1) % self.world_size
    peer_name = "worker{}".format(peer)
    rref = rpc.remote(peer_name, torch.add, args=(torch.tensor(1), torch.tensor(1)))
    timeout_regex = self.get_timeout_error_regex()
    with self.assertRaisesRegex(RuntimeError, timeout_regex):
        rref.to_here(0.01)
    # Without an explicit timeout the fetch eventually completes.
    rref.to_here()
def init_workers(num_workers, start, end, delta, isTe, worker_constructor, worker_args):
    """Create one remote worker per work unit; worker 0 is flagged as head."""
    units = get_work_units(num_workers, start, end, delta, isTe)
    return [
        rpc.remote(
            'worker' + str(idx),
            worker_constructor,
            args=(LOAD_FN, units[idx], *worker_args),
            kwargs={'head': idx == 0},
        )
        for idx in range(len(units))
    ]
def __init__(self, world_size):
    """Agent: local policy/optimizer plus one remote Observer per non-agent rank."""
    self.agent_rref = RRef(self)
    self.policy = Policy()
    self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
    # Smallest positive float32 increment, used for numeric stability downstream.
    self.eps = np.finfo(np.float32).eps.item()
    self.running_reward = 0
    self.reward_threshold = DummyEnv().reward_threshold
    self.ob_rrefs = []
    self.rewards = {}
    self.saved_log_probs = {}
    # Rank 0 is this agent; every other rank hosts one Observer.
    for observer_rank in range(1, world_size):
        info = rpc.get_worker_info(worker_name(observer_rank))
        self.ob_rrefs.append(remote(info, Observer))
        self.rewards[info.id] = []
        self.saved_log_probs[info.id] = []
def test_remote_timeout_to_here_in_jit(self):
    """Calling to_here() in JIT raises the creation error when rpc.remote failed."""
    if self.rank != 0:
        return
    peer = (self.rank + 1) % self.world_size
    peer_name = "worker{}".format(peer)
    rref = rpc.remote(peer_name, torch.add, args=(torch.tensor(1), torch.tensor(1)))
    # Run pending error-handling callbacks before touching the RRef.
    wait_until_pending_futures_and_users_flushed()
    # to_here() inside a ScriptFunction must surface the failure.
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rref_to_here(rref)
def __init__(self, world_size):
    """Agent: local policy/optimizer plus one remote gym Observer per non-agent rank."""
    self.agent_rref = RRef(self)
    self.policy = Policy()
    self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
    # Smallest positive float32 increment, used for numeric stability downstream.
    self.eps = numpy.finfo(numpy.float32).eps.item()
    self.running_reward = 0
    self.reward_threshold = gym.make(ENV).spec.reward_threshold
    self.ob_rrefs = []
    self.rewards = {}
    self.saved_log_probs = {}
    # Rank 0 is this agent; every other rank hosts one Observer.
    for observer_rank in range(1, world_size):
        info = rpc.get_worker_info(OBSERVER_NAME.format(observer_rank))
        self.ob_rrefs.append(remote(info, Observer))
        self.rewards[info.id] = []
        self.saved_log_probs[info.id] = []
def _test_remote_message_dropped_timeout(self, func, args, dst=None):
    """When the rpc.remote creation message is dropped, to_here() reports creation failure."""
    if self.rank != 0:
        return
    target_rank = (self.rank + 1) % self.world_size if dst is None else dst
    target = "worker{}".format(target_rank)
    # python_remote_call messages fail synchronously, so the future behind this
    # remote call is already marked with an error by the time this returns.
    rref = rpc.remote(target, func, args=args)
    # Flush pending error-handling callbacks.
    wait_until_pending_futures_and_users_flushed()
    with self.assertRaisesRegex(RuntimeError, "RRef creation"):
        rref.to_here()
def test_async_function_remote_multi(self):
    """Launch many async_add remotes at once and verify each result independently."""
    helper = worker_name((self.rank + 1) % self.world_size)
    adder_dst = worker_name((self.rank + 2) % self.world_size)
    total = 20
    handles = [
        rpc.remote(helper, async_add,
                   args=(adder_dst, torch.ones(2, 2), torch.ones(2, 2) * k))
        for k in range(total)
    ]
    for k, handle in enumerate(handles):
        self.assertEqual(handle.to_here(), torch.ones(2, 2) + k)
def _exec_func(self, exec_mode, method, *args):
    """Run `method` locally, via rpc_sync, or via rpc.remote, depending on exec_mode.

    Args:
        exec_mode: ExecMode value selecting the invocation style.
        method: callable to execute.
        *args: positional arguments; in LOCAL mode a single list argument is
            unpacked so list-wrapped calls behave like the RPC paths.

    Returns:
        The result of `method`.

    Raises:
        ValueError: if exec_mode is not a recognized ExecMode.
    """
    if ExecMode.LOCAL == exec_mode:
        if len(args) == 1 and isinstance(args[0], list):
            return method(*args[0])
        return method(*args)
    elif ExecMode.RPC_SYNC == exec_mode:
        return rpc.rpc_sync('worker{}'.format(self._next_rank()), method,
                            args=args)
    elif ExecMode.REMOTE == exec_mode:
        # RRef.to_here() blocks and returns the call's result directly (it is
        # not a future), so no .wait() — this mirrors the RPC_SYNC branch,
        # which also returns the result as-is.
        rref = rpc.remote('worker{}'.format(self._next_rank()), method,
                          args=args)
        return rref.to_here()
    else:
        raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
def remote_method(method, obj_rref, *args, **kwargs):
    """Call rpc.remote on a method in a remote object.

    Args:
        method: the method (for example, Class.method)
        obj_rref (RRef): remote reference to the object
        args: positional arguments to pass to the method
        kwargs: keyword arguments to pass to the method

    Returns an RRef to the remote method call result.
    """
    # _call_method unwraps the RRef on the owner and invokes the bound method.
    call_args = [method, obj_rref, *args]
    return rpc.remote(obj_rref.owner(), _call_method, args=call_args, kwargs=kwargs)
def __init__(self, config, world_size):
    """Central player: builds the local dataset/agent, then one remote worker per rank.

    Args:
        config: top-level config object; its config_NeuralPlayer section is used here
            and the full object is forwarded to each CentralAgentWorker.
        world_size: total number of RPC participants (this rank plus the workers).
    """
    self.e = 0  # episode counter
    self.config = config.config_NeuralPlayer
    self.preprocessor = None
    self._init_dataset(self.config.config_Datasets)
    self._init_agent(self.config.config_Agent)
    self.agent_rref = RRef(self.agent)
    self.world_size = world_size  # nb of remote agents
    self.worker_rrefs = []
    self.data_gatherer = ScoreDataGatherer()
    # Ranks 1..world_size-1 each host one CentralAgentWorker; the 600 s timeout
    # presumably covers slow worker construction — TODO confirm.
    for worker_rank in range(1, self.world_size):
        worker_info = rpc.get_worker_info(f"worker{worker_rank}")
        self.worker_rrefs.append(
            remote(worker_info, CentralAgentWorker, args=(config, worker_rank), timeout=600))
def run_training_loop(self, N_iter, coord_rref):
    """Run N_iter training iterations, each spawning one EpisodeRunner per worker rank.

    Args:
        N_iter: number of training iterations to run.
        coord_rref: RRef to the coordinator, handed to every episode.
    """
    self.iter = 0
    for i in range(N_iter):
        # create world_size-1 EpisodeRunner objects
        # NOTE(review): `world_size` is neither a parameter nor read from self
        # here — presumably a module-level global; confirm, otherwise this
        # raises NameError at runtime.
        rref_list = [
            rpc.remote(f"rank{j}", EpisodeRunner, (j, ))
            for j in range(1, world_size)
        ]
        # launch episodes
        fut_list = [
            r.rpc_async().run_episode(coord_rref)
            for r in rref_list
        ]
        # block until every episode of this iteration has completed
        [fut.wait() for fut in fut_list]
        # update model
        self.update_model()
def __init__(self, graph: PipelineModulesGraph, chunks: int = 1, checkpoint: str = "except_last",) -> None:
    """Partition a pipeline graph and create one remote PartitionHandler per partition.

    Args:
        graph: module graph describing pipeline stages and their data flow.
        chunks: number of micro-batches; must be a positive integer.
        checkpoint: one of "always", "except_last", or "never".

    Raises:
        ValueError: if chunks is not positive or checkpoint is not a valid mode.
    """
    super().__init__()
    check_pytorch_version()
    # Normalize argument types up front.
    chunks = int(chunks)
    checkpoint = str(checkpoint)
    if chunks <= 0:
        raise ValueError("number of chunks must be positive integer")
    if checkpoint not in ["always", "except_last", "never"]:
        raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
    self.chunks = chunks
    # The micro-batch index where the checkpointing stops.
    checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[checkpoint]
    # One PartitionHandler per partition, created on the rank owning that
    # partition's handler RRef.
    self.partitions = [
        self.Partition(
            nodes,
            rpc.remote(
                handler.owner(),
                PartitionHandler,
                args=(
                    handler,
                    nodes[0].module.device,
                    len(nodes[0].inputs),
                    nodes[-1].num_outputs,
                    i,
                    self.chunks,
                    checkpoint_stop,
                ),
            ),
        )
        for i, (nodes, handler) in enumerate(graph.partition_graph())
    ]
    # For each model input, find the (unique) partition whose first node
    # consumes it and record where it plugs in.
    self.input_consumers = [
        next(
            self.DataConsumer(partition, input_consumer.consumer_input_idx, input_consumer.output_idx)
            for partition in self.partitions
            if partition.nodes[0] is input_consumer.consumer
        )
        for input_consumer in graph.model_input_consumers
    ]
    self.graph = graph
def test_rref_forward_chain(self):
    """Forward an RRef through a chain of hops, then unwrap it hop by hop."""
    hops = 8
    nxt = self.rank + 1
    owner = nxt % self.world_size
    rref = rpc.remote(
        "worker{}".format(owner), torch.add, args=(torch.ones(nxt, nxt), 1)
    )
    chained = rref_forward_chain(owner, self.world_size, rref, hops)
    # Each hop wraps the previous RRef in a singleton list.
    for _ in range(hops):
        self.assertEqual(len(chained), 1)
        chained = chained[0].to_here()
    self.assertEqual(chained, torch.add(torch.ones(nxt, nxt), 1))