def run_driver(rank, world_size, gpu_list, dataset, batch_size, lr, mom, lambd,
               max_epoch, client_epoch, model, seed, q, early_stop_round,
               early_stop_metric):
    # Tag this run with a unique experiment id shared by every node.
    exp_id = str(int(time.time()))
    print(f"Driver initializing RPC, rank {rank}, world size {world_size}")
    rpc.init_rpc(name="driver", rank=rank, world_size=world_size)
    print("Initialized driver")

    # Construct the parameter server on its own RPC node and keep an RRef to it.
    param_server_rref = rpc.remote(
        "parameter_server", get_parameter_server,
        args=(gpu_list[0], world_size - 1, dataset, batch_size, lr, mom, lambd,
              model, max_epoch, client_epoch, seed, exp_id,
              early_stop_round, early_stop_metric))

    # Construct each trainer and bind it to the parameter server in both directions.
    for _rank in range(1, world_size - 1):
        print(f"Driver registering worker node {_rank}")
        worker_server_rref = rpc.remote(
            f"trainer_{_rank}", get_worker,
            args=(gpu_list[_rank], _rank, world_size - 1, dataset, model,
                  batch_size, lr, seed, exp_id))
        print(f"Driver binding worker {_rank} with param server")
        remote_method(ParameterServer.embedding_workers,
                      param_server_rref, worker_server_rref)
        remote_method(TrainerNet.embedding_param_server,
                      worker_server_rref, param_server_rref)

    # Start training on the parameter server and block until it completes.
    fut = remote_method_async(ParameterServer.instruct_training, param_server_rref)
    fut.wait()

    # Collect the final accuracy and hand it back to the launching process.
    final_accuracy = remote_method(ParameterServer.get_final_accuract,
                                   param_server_rref)
    q.put(final_accuracy)

    rpc.shutdown()
    print("RPC shutdown on Driver")
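# The driver (and every worker method below) relies on remote_method /
# remote_method_async helpers that are not shown here. A minimal sketch,
# assuming the standard PyTorch RPC pattern of running a bound method on the
# object owned by an RRef; the helper names mirror the calls above, but this
# is not necessarily the project's exact implementation.
import torch.distributed.rpc as rpc

def call_method(method, rref, *args, **kwargs):
    # Executed on the RRef's owner: dispatch the method on the local object.
    return method(rref.local_value(), *args, **kwargs)

def remote_method(method, rref, *args, **kwargs):
    # Blocking call on the node that owns `rref`; returns the method's result.
    return rpc.rpc_sync(rref.owner(), call_method,
                        args=[method, rref] + list(args), kwargs=kwargs)

def remote_method_async(method, rref, *args, **kwargs):
    # Non-blocking variant; returns a Future the caller can wait() on.
    return rpc.rpc_async(rref.owner(), call_method,
                         args=[method, rref] + list(args), kwargs=kwargs)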
def synchronize(self):
    sync_time = 0.
    # Push every local parameter tensor to the parameter server on CPU.
    for name, param in self.model.named_parameters():
        remote_method(ParameterServer.synchronize, self.param_server_rref,
                      self.rank, name, copy.deepcopy(param.data).to("cpu"))
    # Signal the server that this worker has finished uploading.
    remote_method(ParameterServer.sync_counter, self.param_server_rref)
    return sync_time
def embedding_param_server(self, _rref):
    # Keep a handle to the parameter server and report this worker's data size,
    # which the server uses as its aggregation weight.
    self.param_server_rref = _rref
    if self.model_name == "transformer":
        # Language-model loaders are weighted by number of batches.
        remote_method(ParameterServer.worker_weight_update,
                      self.param_server_rref, self.rank, len(self.train_loader))
    else:
        # Other loaders are weighted by number of samples.
        remote_method(ParameterServer.worker_weight_update,
                      self.param_server_rref, self.rank,
                      len(self.train_loader.dataset))
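# The call above assumes a ParameterServer.worker_weight_update that records each
# worker's data size. One possible sketch; normalizing the recorded sizes into
# aggregation weights (self.data_sizes, self.weights) is an assumption, not
# something the snippets here confirm.
def worker_weight_update(self, rank, data_size):
    # Remember how much data this worker holds, then renormalize all weights.
    self.data_sizes[rank] = data_size
    total = sum(self.data_sizes.values())
    self.weights = {r: n / total for r, n in self.data_sizes.items()}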
def synchronize(self):
    sync_time = 0.
    for name, param in self.model.named_parameters():
        # print(f"Worker {self.rank} starts to synchronize {name}")
        remote_method(ParameterServer.synchronize, self.param_server_rref,
                      self.rank, name, copy.deepcopy(param.data).to("cpu"))
        # Charge a simulated transfer cost of 1.6e-06 s per tensor element.
        sync_time += 1.6e-06 * param.data.numel()
    # Synchronization completed; send a signal to the server.
    # print(f"Worker {self.rank} finished synchronizing")
    remote_method(ParameterServer.sync_counter, self.param_server_rref)
    return sync_time
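# The worker-side synchronize above talks to ParameterServer.synchronize and
# ParameterServer.sync_counter. A possible sketch of the receiving end, assuming
# FedAvg-style aggregation into zero-initialized CPU buffers guarded by a
# threading.Lock; self.lock, self.buffer, self.weights, self.num_workers and the
# weighting scheme are assumptions, not confirmed by the snippets here.
def synchronize(self, rank, name, tensor):
    # Fold this worker's copy of parameter `name` into the global buffer,
    # weighted by the worker's share of the training data.
    with self.lock:
        self.buffer[name] += self.weights[rank] * tensor

def sync_counter(self):
    # Count finished workers; once all have reported, publish the averaged
    # buffers as the new global model and reset them for the next round.
    with self.lock:
        self.finished += 1
        if self.finished == self.num_workers:
            for name, param in self.model.named_parameters():
                param.data = self.buffer[name].to(param.device)
                self.buffer[name] = torch.zeros_like(param.data, device="cpu")
            self.finished = 0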
def fetch_state(self):
    fetch_time = 0.
    # Pull the latest global value of every parameter from the parameter server.
    for name, param in self.model.named_parameters():
        param.data = copy.deepcopy(
            remote_method(ParameterServer.distribute_state,
                          self.param_server_rref, name)).to(self.device)
    return fetch_time
def fetch_state_block(self):
    sync_time = 0.
    for name, param in self.model.named_parameters():
        param.data = copy.deepcopy(
            remote_method(ParameterServer.distribute_state,
                          self.param_server_rref, name)).to(self.device)
        # Charge a simulated transfer cost of 1.6e-06 s per tensor element.
        sync_time += 1.6e-06 * param.data.numel()
    return sync_time
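# fetch_state and fetch_state_block assume a ParameterServer.distribute_state
# that returns one named parameter of the global model. A minimal sketch,
# assuming the weights are shipped on CPU so they can be serialized over RPC:
def distribute_state(self, name):
    # Hand the current global value of parameter `name` back to the caller.
    return self.model.state_dict()[name].to("cpu")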
def fetch_state_block(self):
    # Fetch the global state into a detached copy of the local state dict, then
    # load it in one shot so the local model is never left half-updated.
    static_state_dict = copy.deepcopy(self.model.state_dict())
    for name, param in static_state_dict.items():
        static_state_dict[name] = copy.deepcopy(
            remote_method(ParameterServer.distribute_state,
                          self.param_server_rref, name)).to(self.device)
        # Simulate the transfer latency: 1.6e-06 s per tensor element.
        time.sleep(1.6e-06 * param.numel())
    self.model.load_state_dict(static_state_dict)
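# How the worker-side pieces might fit together in one communication round;
# a sketch only: run_epoch and self.client_epoch are assumptions standing in
# for the worker's local training loop and its number of local epochs.
def run_round(self):
    # Pull the current global model, train locally, then push the update back,
    # accumulating the simulated communication time along the way.
    comm_time = self.fetch_state_block()
    for _ in range(self.client_epoch):
        self.run_epoch()
    comm_time += self.synchronize()
    return comm_time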
def embedding_param_server(self, _rref):
    # print(f"Worker {self.rank} add a parameter server reference")
    # Keep a handle to the parameter server and report the local dataset size,
    # which the server uses as this worker's aggregation weight.
    self.param_server_rref = _rref
    remote_method(ParameterServer.worker_weight_update, self.param_server_rref,
                  self.rank, len(self.train_loader.dataset))
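# A minimal launch sketch, assuming one process per RPC node on a single machine:
# rank 0 hosts the parameter server, ranks 1..world_size-2 host trainers, and the
# last rank runs run_driver above. run_param_server and run_worker are hypothetical
# entry points mirroring run_driver; the rank layout and port are assumptions.
import os
import torch.multiprocessing as mp

def launch(world_size, gpu_list, cfg):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    q = mp.Queue()
    procs = [mp.Process(target=run_param_server, args=(0, world_size))]
    for r in range(1, world_size - 1):
        procs.append(mp.Process(target=run_worker, args=(r, world_size)))
    procs.append(mp.Process(
        target=run_driver,
        args=(world_size - 1, world_size, gpu_list, cfg["dataset"],
              cfg["batch_size"], cfg["lr"], cfg["mom"], cfg["lambd"],
              cfg["max_epoch"], cfg["client_epoch"], cfg["model"], cfg["seed"],
              q, cfg["early_stop_round"], cfg["early_stop_metric"])))
    for p in procs:
        p.start()
    final_accuracy = q.get()  # run_driver puts the final accuracy on the queue
    for p in procs:
        p.join()
    return final_accuracy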