Example #1
 def synchronize(self):
     futs = []
     sync_time = 0.0
     # Push each parameter to the parameter server asynchronously; this variant
     # does not measure or estimate communication time (compare Example #4).
     for name, param in self.model.named_parameters():
         fut = remote_method_async(ParameterServer.synchronize, self.param_server_rref, self.rank,
                                   name, copy.deepcopy(param.data).to("cpu"))
         futs.append(fut)
     self.sync_future_all = torch.futures.collect_all(futs)
     self.sync_future_all.then(lambda x: remote_method_async(
         ParameterServer.sync_counter, self.param_server_rref
     ))
     return sync_time
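All of these examples funnel their calls through remote_method / remote_method_async helpers that the snippets assume but never show. A minimal sketch of such helpers, following the pattern from the PyTorch RPC parameter-server tutorial (the original project's definitions may differ):

import torch.distributed.rpc as rpc

def call_method(method, rref, *args, **kwargs):
    # Run `method` on the object the RRef points to, on the RRef's owning worker.
    return method(rref.local_value(), *args, **kwargs)

def remote_method(method, rref, *args, **kwargs):
    # Blocking remote call; returns the method's result.
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)

def remote_method_async(method, rref, *args, **kwargs):
    # Non-blocking variant; returns a torch.futures.Future, as used above.
    args = [method, rref] + list(args)
    return rpc.rpc_async(rref.owner(), call_method, args=args, kwargs=kwargs)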
Example #2
 def synchronize(self):
     # print(f"Worker {self.rank} starts synchronizing")
     futs = []
     # Snapshot the state dict so the pushed values are not mutated mid-flight;
     # an nn.Module has no .items(), so iterate over the copied state_dict instead.
     static_state_dict_sync = copy.deepcopy(self.model.state_dict())
     for name, param in static_state_dict_sync.items():
         # print(f"Worker {self.rank} starts to synchronize {name}")
         fut = remote_method_async(ParameterServer.synchronize, self.param_server_rref, self.rank,
                                   name, copy.deepcopy(param.data).to("cpu"))
         futs.append(fut)
     self.sync_future_all = torch.futures.collect_all(futs)
     self.sync_future_all.then(lambda x: remote_method_async(
         ParameterServer.sync_counter, self.param_server_rref
     ))
Example #3
 def broadcast_state(self):
     broadcast_futs = []
     v_train = (1 + self.dyn_task) / self.dyn_timer
     self.dyn_task = 0.9 * v_train / v_train.min() - 0.9
     self.dyn_timer *= 0.
     for w_rank, worker_rref in enumerate(self.embedding_list):
         extra_work_fut = remote_method_async(TrainerNet.recv_extra_work, worker_rref, self.dyn_task[w_rank])
         broadcast_futs.append(extra_work_fut)
         for name, param in self.model.named_parameters():
             broad_fut = remote_method_async(TrainerNet.recv_state, worker_rref, name,
                                             param.data.to("cpu"))
             broadcast_futs.append(broad_fut)
     self.broadcast_fut_all = torch.futures.collect_all(broadcast_futs)
     self.broadcast_fut_all.then(lambda x: self.set_cluster_ready())
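Example #3 uses the per-worker timings accumulated in dyn_timer to hand extra work to the faster workers before broadcasting the new model state: v_train is each worker's relative throughput, and the update gives the slowest worker zero extra work and faster workers more, in proportion to how much they outpace it (scaled by the 0.9 factor from the code above). A standalone sketch with made-up timings for three hypothetical workers:

import torch

dyn_timer = torch.tensor([10.0, 12.5, 20.0])   # seconds spent training last round
dyn_task = torch.zeros(3)                      # extra work assigned previously

v_train = (1 + dyn_task) / dyn_timer           # relative throughput: [0.100, 0.080, 0.050]
dyn_task = 0.9 * v_train / v_train.min() - 0.9
print(dyn_task)                                # tensor([0.9000, 0.5400, 0.0000])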
Example #4
 def synchronize(self):
     # print(f"Worker {self.rank} starts synchronizing")
     futs = []
     sync_time = 0.0
     for name, param in self.model.named_parameters():
         # print(f"Worker {self.rank} starts to synchronize {name}")
         sync_time += (1.6e-06 * param.size().numel())
         fut = remote_method_async(ParameterServer.synchronize, self.param_server_rref, self.rank,
                                   name, copy.deepcopy(param.data).to("cpu"))
         futs.append(fut)
     self.sync_future_all = torch.futures.collect_all(futs)
     self.sync_future_all.then(lambda x: remote_method_async(
         ParameterServer.sync_counter, self.param_server_rref
     ))
     return sync_time
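Example #4 is the same synchronization routine as Example #1, except that it also estimates the communication cost analytically, at roughly 1.6 µs per parameter element, instead of measuring it. A small sanity check of that estimate on a hypothetical layer:

import torch

layer = torch.nn.Linear(512, 512)
est = sum(1.6e-06 * p.numel() for p in layer.parameters())
print(f"{est:.3f} s")   # ~0.420 s for the 512*512 weight matrix plus 512 biases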
Example #5
def run_driver(rank, world_size, gpu_list, dataset, batch_size,
               lr, mom, lambd, max_epoch, client_epoch, model, seed, q,
               early_stop_round, early_stop_metric):
    exp_id = str(int(time.time()))
    print(f"Driver initializing RPC, rank {rank}, world size {world_size}")
    rpc.init_rpc(name="driver", rank=rank, world_size=world_size)
    print("Initialized driver")
    param_server_rref = rpc.remote("parameter_server", get_parameter_server,
                                   args=(gpu_list[0], world_size - 1, dataset, batch_size,
                                         lr, mom, lambd, model, max_epoch, client_epoch,
                                         seed, exp_id, early_stop_round, early_stop_metric))
    for _rank in range(1, world_size - 1):
        print(f"Driver registering worker node {_rank}")
        worker_server_rref = rpc.remote(f"trainer_{_rank}", get_worker,
                                        args=(gpu_list[_rank], _rank, world_size - 1, dataset,
                                              model, batch_size, lr, seed, exp_id))
        print(f"Driver binding worker {_rank} with param server")
        remote_method(ParameterServer.embedding_workers, param_server_rref, worker_server_rref)
        remote_method(TrainerNet.embedding_param_server, worker_server_rref, param_server_rref)

    fut = remote_method_async(ParameterServer.instruct_training, param_server_rref)
    fut.wait()
    final_accuracy = remote_method(ParameterServer.get_final_accuract, param_server_rref)
    q.put(final_accuracy)
    rpc.shutdown()
    print("RPC shutdown on Driver")
Example #6
    def instruct_training(self):
        total_train_time = 0.
        acc_list = []

        if wandb_enable:
            wandb.watch(self.model)

        for curr_epoch in range(self.max_epoch):
            logger.info(f"PS instructs training for epoch {curr_epoch}")
            epoch_time = []
            futs = []
            for worker_rref in self.embedding_list:
                fut = remote_method_async(TrainerNet.train_locally,
                                          worker_rref)
                futs.append(fut)
            for fut in futs:
                train_time, comm_time = fut.wait()
                epoch_time.append(train_time + comm_time)
            total_train_time += max(epoch_time)
            logger.info(
                f"Cluster finished training for epoch {curr_epoch}, max epoch {self.max_epoch - 1}"
            )
            logger.info(f"Epoch {curr_epoch} takes {max(epoch_time)} seconds")

            acc = get_accuracy(self.test_loader, self.model, self.device,
                               self.model_name == "transformer")
            logger.info(f"Accuracy: {acc}")
            acc_list.append(acc)
            if wandb_enable:
                wandb.log({"accuracy": acc}, step=curr_epoch)
                wandb.log(
                    {
                        "training time": max(epoch_time),
                        "total train time": total_train_time
                    },
                    step=curr_epoch)

        logger.info("Training complete!")
        acc = get_accuracy(self.test_loader, self.model, self.device,
                           self.model_name == "transformer")
        acc_list.append(acc)
        logger.info(f"Total train time: {total_train_time}")
        logger.info(f"Best accuracy {max(acc_list)}")
        if wandb_enable:
            wandb.log({"accuracy": acc}, step=self.max_epoch - 1)
            wandb.finish()
Example #7
    def instruct_training(self):
        total_train_time = 0.
        # pre-initialization
        for name, param in self.model.named_parameters():
            self.wtminus1[name] = copy.deepcopy(param.data)
            self.mom_buffer[name] = torch.zeros(param.data.shape).to(self.device)

        if wandb_enable:
            wandb.watch(self.model)

        for curr_epoch in range(self.max_epoch):
            if curr_epoch == int(self.early_stop_round) * self.client_epoch \
                    and len(self.acc_list) != 0 \
                    and max(self.acc_list[-5:]) < self.early_stop_metric:
                print(f"{curr_epoch // self.client_epoch - 1}: {max(self.acc_list[-5:])}, "
                      f"This experiment seems to be bad, early stopping")
                self.max_epoch = curr_epoch
                break

            if curr_epoch % self.client_epoch == 0:
                logger.info(f"PS instructs training for epoch {curr_epoch // self.client_epoch}")
                epoch_time = 0.
            cluster_is_ready = self.cluster_is_ready & (curr_epoch > 0) & ((curr_epoch + 1) % self.client_epoch != 0)
            self.cluster_is_ready = (not cluster_is_ready) | (curr_epoch == 0)
            if curr_epoch > 0 and curr_epoch % self.client_epoch == 0:
                # nestorv_reversed_model = copy.deepcopy(self.model).to(self.device)
                # for name, param in nestorv_reversed_model.named_parameters():
                #     param.data = param.data + self.lr * self.mom * self.mom_buffer[name]
                acc = get_accuracy(self.test_loader, self.model, self.device)
                # del nestorv_reversed_model
                logger.info(f"Accuracy: {acc}")
                self.acc_list.append(acc)
                if wandb_enable:
                    wandb.log({"accuracy": acc}, step=curr_epoch // self.client_epoch - 1)

            # epoch_start_time = time.time()
            futs = []
            train_time_list = [0]
            for worker_rref in self.embedding_list:
                fut = remote_method_async(TrainerNet.train_locally, worker_rref,
                                          curr_epoch, cluster_is_ready)
                futs.append(fut)

            for fut in futs:
                w_rank, train_time, comm_time = fut.wait()
                self.dyn_timer[w_rank - 1] += train_time
                train_time_list.append(train_time)
            # epoch_end_time = time.time()
            epoch_time += max(train_time_list)
            if (curr_epoch + 1) % self.client_epoch == 0:
                total_train_time += epoch_time
                logger.info(f"Cluster finished training for epoch {curr_epoch // self.client_epoch}, "
                            f"max epoch {self.max_epoch // self.client_epoch - 1}")
                logger.info(f"Epoch {curr_epoch // self.client_epoch} takes {epoch_time} seconds")
                if wandb_enable:
                    wandb.log({"training time": epoch_time, "total train time": total_train_time}, step=curr_epoch // self.client_epoch)

        self.cluster_is_ready = False
        for worker_rref in self.embedding_list:
            remote_method_async(TrainerNet.synchronize, worker_rref)
        while not self.cluster_is_ready:
            print(f"PS waiting for final synchronization")
            time.sleep(1)
        logger.info("Training complete!")

        # nestorv_reversed_model = copy.deepcopy(self.model).to(self.device)
        # for name, param in nestorv_reversed_model.named_parameters():
        #     param.data = param.data + self.lr * self.mom * self.mom_buffer[name]
        acc = get_accuracy(self.test_loader, self.model, self.device)
        # del nestorv_reversed_model
        logger.info(f"Accuracy: {acc}")
        self.acc_list.append(acc)
        logger.info(f"Total train time: {total_train_time}")
        logger.info(f"Best accuracy {max(self.acc_list)}")
        if wandb_enable:
            wandb.log({"accuracy": acc}, step=self.max_epoch // self.client_epoch - 1)
            wandb.finish()
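The curr_epoch loop in Example #7 counts local rounds, not reporting epochs: accuracy is evaluated once every client_epoch rounds, and curr_epoch // client_epoch is the "global" epoch index used in the logs. A tiny illustration of that bookkeeping with assumed values:

client_epoch, max_epoch = 5, 15                          # assumed: 3 global epochs of 5 local rounds
for curr_epoch in range(max_epoch):
    if curr_epoch % client_epoch == 0:                   # start of a global epoch: reset the timer
        print(f"global epoch {curr_epoch // client_epoch} starts at local round {curr_epoch}")
    if (curr_epoch + 1) % client_epoch == 0:             # last local round: log time, evaluate accuracy
        print(f"global epoch {curr_epoch // client_epoch} ends at local round {curr_epoch}")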