Beispiel #1
0
def run_driver(rank, world_size, gpu_list, dataset, batch_size,
               lr, mom, lambd, max_epoch, client_epoch, model, seed, q,
               early_stop_round, early_stop_metric):
    """Run the driver process: create the remote nodes, link them, start training.

    Joins the RPC group under the name "driver", asks the node named
    "parameter_server" to build a parameter server, asks every node named
    "trainer_<r>" (r = 1 .. world_size - 2) to build a worker, and introduces
    each worker and the parameter server to one another. Training is then
    started on the parameter server; once the returned future completes, the
    final accuracy is put on ``q`` and RPC is shut down.

    Args:
        rank: RPC rank of this driver process.
        world_size: Total number of RPC participants.
        gpu_list: Indexable; entry 0 goes to the parameter server and entry
            ``r`` to worker ``r``.
        dataset, batch_size, lr, model, seed: Forwarded unchanged to both the
            parameter server and the workers.
        mom, lambd, max_epoch, client_epoch, early_stop_round,
        early_stop_metric: Forwarded unchanged to the parameter server only.
        q: Object exposing ``put``; receives the final accuracy.
    """
    # Second-resolution timestamp shared by all nodes as the experiment id.
    exp_id = str(int(time.time()))

    print(f"Driver initializing RPC, rank {rank}, world size {world_size}")
    rpc.init_rpc(name="driver", rank=rank, world_size=world_size)
    print("Initialized driver")

    # Both factories receive world_size - 1 as their node-count argument.
    node_count = world_size - 1

    server_args = (
        gpu_list[0], node_count, dataset, batch_size,
        lr, mom, lambd, model, max_epoch, client_epoch,
        seed, exp_id, early_stop_round, early_stop_metric,
    )
    ps_rref = rpc.remote("parameter_server", get_parameter_server, args=server_args)

    for _rank in range(1, node_count):
        print(f"Driver registering worker node {_rank}")
        worker_args = (
            gpu_list[_rank], _rank, node_count, dataset,
            model, batch_size, lr, seed, exp_id,
        )
        worker_rref = rpc.remote(f"trainer_{_rank}", get_worker, args=worker_args)

        print(f"Driver binding worker {_rank} with param server")
        # Two-way introduction: the server learns the worker, then the worker
        # learns the server.
        remote_method(ParameterServer.embedding_workers, ps_rref, worker_rref)
        remote_method(TrainerNet.embedding_param_server, worker_rref, ps_rref)

    # Start training asynchronously and block until its future resolves.
    remote_method_async(ParameterServer.instruct_training, ps_rref).wait()

    # NOTE(review): "get_final_accuract" is the attribute name used by this
    # call site; its definition is outside this view.
    q.put(remote_method(ParameterServer.get_final_accuract, ps_rref))

    rpc.shutdown()
    print("RPC shutdown on Driver")
Beispiel #2
0
 def synchronize(self):
     """Push every named model parameter to the parameter server.

     For each ``(name, parameter)`` pair of ``self.model`` a deep copy of the
     parameter data is moved to the CPU and sent through
     ``ParameterServer.synchronize`` together with this worker's rank and the
     parameter name. After the last parameter, ``ParameterServer.sync_counter``
     is invoked once.

     Returns:
         float: Always ``0.`` -- this variant does not accumulate any time.
     """
     for param_name, tensor in self.model.named_parameters():
         remote_method(
             ParameterServer.synchronize,
             self.param_server_rref,
             self.rank,
             param_name,
             # Copy first so the model's own tensor is never the object sent.
             copy.deepcopy(tensor.data).to("cpu"),
         )

     # Signal that this worker has pushed all of its parameters.
     remote_method(ParameterServer.sync_counter, self.param_server_rref)
     return 0.
 def embedding_param_server(self, _rref):
     """Store the parameter-server reference and report this worker's weight.

     The weight sent through ``ParameterServer.worker_weight_update`` is
     ``len(self.train_loader)`` when ``self.model_name`` is ``"transformer"``
     and ``len(self.train_loader.dataset)`` otherwise.

     Args:
         _rref: Reference to the parameter server; kept on
             ``self.param_server_rref`` for later calls.
     """
     self.param_server_rref = _rref

     # NOTE(review): if train_loader is a DataLoader, the first branch counts
     # batches while the second counts samples -- confirm this is intended.
     if self.model_name == "transformer":
         weight = len(self.train_loader)
     else:
         weight = len(self.train_loader.dataset)

     remote_method(
         ParameterServer.worker_weight_update,
         self.param_server_rref,
         self.rank,
         weight,
     )
Beispiel #4
0
 def synchronize(self):
     """Push every named model parameter to the parameter server.

     Each parameter's data is deep-copied, moved to the CPU and sent through
     ``ParameterServer.synchronize`` with this worker's rank and the parameter
     name. ``ParameterServer.sync_counter`` is called once all parameters have
     been sent.

     Returns:
         float: Sum over all parameters of ``1.6e-06 * element_count``.
         NOTE(review): presumably a simulated per-element transfer cost in
         seconds -- confirm against the caller.
     """
     elapsed = 0.
     for key, weights in self.model.named_parameters():
         cpu_payload = copy.deepcopy(weights.data).to("cpu")
         remote_method(
             ParameterServer.synchronize,
             self.param_server_rref,
             self.rank,
             key,
             cpu_payload,
         )
         # Account for this parameter in proportion to its element count.
         elapsed += (1.6e-06 * weights.data.size().numel())

     # All parameters sent: notify the parameter server.
     remote_method(ParameterServer.sync_counter, self.param_server_rref)
     return elapsed
Beispiel #5
0
 def fetch_state(self):
     """Replace every named model parameter with the parameter server's value.

     For each parameter name, the value returned by
     ``ParameterServer.distribute_state`` is deep-copied, moved to
     ``self.device`` and assigned to the parameter's ``data``.

     Returns:
         float: Always ``0.`` -- this variant does not accumulate any time.
     """
     for key, weights in self.model.named_parameters():
         received = remote_method(
             ParameterServer.distribute_state, self.param_server_rref, key
         )
         # Deep-copy before the device move so the received object is untouched.
         weights.data = copy.deepcopy(received).to(self.device)
     return 0.
Beispiel #6
0
 def fetch_state_block(self):
     """Pull every named parameter from the parameter server into the model.

     Each value returned by ``ParameterServer.distribute_state`` is deep-copied,
     moved to ``self.device`` and written to the matching parameter's ``data``.

     Returns:
         float: Sum over all parameters of ``1.6e-06 * element_count``, where
         the count is taken from the freshly assigned data.
         NOTE(review): presumably a simulated transfer time in seconds --
         confirm against the caller.
     """
     elapsed = 0.0
     for key, weights in self.model.named_parameters():
         received = remote_method(
             ParameterServer.distribute_state, self.param_server_rref, key
         )
         weights.data = copy.deepcopy(received).to(self.device)
         # Measured after the assignment, i.e. on the newly fetched tensor.
         elapsed += (1.6e-06 * weights.data.size().numel())
     return elapsed
 def fetch_state_block(self):
     """Reload the model's whole state dict from the parameter server.

     Works on a deep copy of ``self.model.state_dict()``: every entry is
     replaced by a deep copy of the value returned by
     ``ParameterServer.distribute_state`` (moved to ``self.device``), and the
     completed dict is loaded back into the model in one step.

     After each entry the method sleeps for ``1.6e-06`` seconds per element of
     the *local* (pre-fetch) entry. Returns ``None``.
     """
     snapshot = copy.deepcopy(self.model.state_dict())

     # Only values of existing keys are reassigned, so the dict's size is
     # unchanged while it is being iterated.
     for key, local_entry in snapshot.items():
         received = remote_method(
             ParameterServer.distribute_state, self.param_server_rref, key
         )
         snapshot[key] = copy.deepcopy(received).to(self.device)
         # Delay is sized by the local entry, not by the value just received.
         time.sleep(1.6e-06 * local_entry.data.size().numel())

     self.model.load_state_dict(snapshot)
Beispiel #8
0
 def embedding_param_server(self, _rref):
     """Store the parameter-server reference and report this worker's weight.

     The weight sent through ``ParameterServer.worker_weight_update`` is
     ``len(self.train_loader.dataset)``, accompanied by this worker's rank.

     Args:
         _rref: Reference to the parameter server; kept on
             ``self.param_server_rref`` for later calls.
     """
     self.param_server_rref = _rref

     remote_method(
         ParameterServer.worker_weight_update,
         self.param_server_rref,
         self.rank,
         len(self.train_loader.dataset),
     )