def __init_opt(self):
    """Build the server-side optimizer from the config, if one is requested.

    When ``config.SERVER_OPT`` is ``None`` the server optimizer is disabled
    (``self.server_opt = None``); otherwise the named optimizer class is
    looked up in ``TorchOptRepo`` and instantiated over the global model's
    parameters with the configured learning rate and extra kwargs.
    """
    opt_name = self.config.SERVER_OPT
    if opt_name is None:
        # No server optimizer configured — plain averaging will be used.
        self.server_opt = None
        return
    opt_factory = TorchOptRepo.name2cls(opt_name)
    self.server_opt = opt_factory(
        self.model.parameters(),
        lr=self.config.SERVER_LEARNING_RATE,
        **self.config.SERVER_OPT_ARGS,
    )
def switch_to_sgd(self, lr):
    """Replace this client's optimizer class with plain SGD at the given lr.

    Drops the current optimizer instance, its temp files, and any pending
    optimizer state to be loaded; the previous weight-decay setting is kept.

    Arguments:
        lr -- learning rate for the new SGD optimizer.
    """
    # TODO __del__
    self.__delete_objects_tmp_files()
    self.opt_cls = TorchOptRepo.name2cls("SGD")
    # Discard the live optimizer and any state queued for loading;
    # a fresh SGD instance will be created on the next training round.
    self.opt = None
    self.__opt_state_to_be_loaded = None
    # Carry the existing weight decay over to the new optimizer params.
    kept_weight_decay = self.opt_cls_param["weight_decay"]
    self.opt_cls_param = {
        "lr": lr,
        "weight_decay": kept_weight_decay,
    }
def __train_one_round(self, curr_round: int):
    """Run one round of federated training.

    Selects a client sample, trains each client locally, cumulatively
    averages the returned model (and optionally optimizer) states, then
    applies the averaged update on the server — either through the server
    optimizer (gradients = averaged state) or by loading the averaged
    state directly. With SCAFFOLD enabled, client control variates are
    averaged and folded into the server control variate ``self.c``.

    Arguments:
        curr_round {int} -- index of the current federated round,
            forwarded to each client's ``train_round``.
    """
    self.model.train()
    client_sample = self.__select_clients()
    # Cumulative averages over the sampled clients, built incrementally
    # (index i is passed to the averaging helpers as the running count).
    comm_avg_model_state = None
    comm_avg_opt_state = None
    if self.config.SCAFFOLD:
        # Cumulative average of the clients' control variates this round.
        comm_c = None
        logging.info("SCAFFOLD: comm_c initialized")
    with ElapsedTime("Training clients"):
        for i, client in enumerate(client_sample):
            # Every client starts the round from the current global model.
            client.set_model(self.model.state_dict())
            if (self.config.CLIENT_OPT_STRATEGY == "avg") and (
                self.avg_opt_state is not None
            ):
                # "avg" strategy: seed the client optimizer with last
                # round's averaged optimizer state.
                client.set_opt_state(self.avg_opt_state)
            if self.config.SCAFFOLD:
                client.set_server_c(self.c)
            if self.config.SCAFFOLD:
                # SCAFFOLD clients additionally return their control
                # variate c, which is folded into the running average.
                model_state, opt_state, c = client.train_round(
                    self.config.N_EPOCH_PER_CLIENT, curr_round
                )
                comm_c = commulative_avg_params(comm_c, c, i)
                logging.info("SCAFFOLD: c received from client")
            else:
                model_state, opt_state = client.train_round(
                    self.config.N_EPOCH_PER_CLIENT, curr_round
                )
            comm_avg_model_state = commulative_avg_model_state_dicts(
                comm_avg_model_state, model_state, i
            )
            if self.config.CLIENT_OPT_STRATEGY == "avg":
                if comm_avg_opt_state is not None:
                    # opt_state is a sequence of state dicts; average each
                    # element pairwise with the running average.
                    comm_avg_opt_state = [
                        commulative_avg_model_state_dicts(opt_s[0], opt_s[1], i)
                        for opt_s in zip(comm_avg_opt_state, opt_state)
                    ]
                else:
                    # First client of the round: adopt its state as the seed.
                    comm_avg_opt_state = opt_state
    with ElapsedTime("Setting gradients"):
        if self.server_opt is not None:
            # Server-optimizer path: treat the averaged client update as a
            # pseudo-gradient and take one server optimizer step.
            self.server_opt.zero_grad()
            server_opt_state = self.server_opt.state_dict()
            self.__set_model_grads(comm_avg_model_state)
            # Rebuild the optimizer over the (possibly refreshed) parameter
            # references, then restore its saved state before stepping.
            # NOTE(review): presumably needed because __set_model_grads
            # invalidates the old parameter bindings — confirm.
            self.server_opt = TorchOptRepo.name2cls(self.config.SERVER_OPT)(
                self.model.parameters(),
                lr=self.config.SERVER_LEARNING_RATE,
                **self.config.SERVER_OPT_ARGS,
            )
            self.server_opt.load_state_dict(server_opt_state)
            self.server_opt.step()
        else:
            # No server optimizer: classic FedAvg — load the average directly.
            self.__log("setting avg model state")
            self.model.load_state_dict(comm_avg_model_state)
    if self.config.SCAFFOLD:
        # Update the server control variate, weighting the averaged client
        # delta by the sampled fraction of clients.
        self.c = lambda2_params(
            self.c,
            comm_c,
            lambda a, b: a + (b * len(client_sample) / self.config.N_CLIENTS),
        )
        logging.info("SCAFFOLD: server c updated")
    # Keep the averaged optimizer state for seeding clients next round.
    self.avg_opt_state = comm_avg_opt_state
    comm_avg_opt_state = None
def __init__(
    self,
    experiment: Experiment,
    config: TorchFederatedLearnerConfig,
    config_technical: TorchFederatedLearnerTechnicalConfig,
) -> None:
    """Initialises the training.

    Seeds all RNGs (when configured), logs the configuration to Comet.ml,
    builds the global model and server optimizer, loads the data, and
    creates one TorchClient per training dataloader.

    Arguments:
        experiment {Experiment} -- Comet.ml experiment object for online logging.
        config {TorchFederatedLearnerConfig} -- Training configuration description.
        config_technical {TorchFederatedLearnerTechnicalConfig} -- Technical
            (storage/placement) configuration.
    """
    super().__init__()
    seed = config.SEED
    if seed is not None:
        # Make the run reproducible across python, numpy and torch RNGs,
        # and force deterministic cuDNN kernels.
        random.seed(seed)
        np.random.seed(seed)
        th.manual_seed(seed)
        th.backends.cudnn.deterministic = True
        th.backends.cudnn.benchmark = False
    self.device = th.device("cuda" if th.cuda.is_available() else "cpu")
    self.experiment = experiment
    self.config = config
    self.config.set_defaults()
    self.experiment.log_parameters(self.config.flatten())
    self.config_technical = config_technical
    self.config_technical.check()
    self.experiment.log_parameters(self.config_technical.flatten())
    # Per-experiment checkpoint directory under the temp dir.
    self.PATH = self.tmp_dir / f"{self.experiment.id}_checkpoints"
    os.makedirs(self.PATH, exist_ok=True)
    net_cls = self.get_model_cls()
    self.model = net_cls().to(self.device)
    self.__init_opt()
    self.avg_opt_state = None
    if self.config.SCAFFOLD:
        # Server control variate starts at zero for every parameter.
        self.c = lambda_params(self.model.parameters(), th.zeros_like)
        logging.info("SCAFFOLD: server c initialized")
    self.train_loader_list, self.test_loader, self.random_acc = self.load_data()
    self.n_train_batches = len(self.train_loader_list[0])
    logging.info(f"Number of training batches: {self.n_train_batches}")
    TorchClient.reset_ID_counter()
    # One client per training dataloader.
    self.clients = []
    for loader in self.train_loader_list:
        self.clients.append(
            TorchClient(
                self,
                model_cls=net_cls,
                is_keep_model_on_gpu=not self.config_technical.STORE_MODEL_IN_RAM,
                is_store_opt_on_disk=self.config_technical.STORE_OPT_ON_DISK,
                loss=self.get_loss(),
                dataloader=loader,
                device=self.device,
                opt_cls=TorchOptRepo.name2cls(self.config.CLIENT_OPT),
                opt_cls_param={
                    "lr": self.config.CLIENT_LEARNING_RATE,
                    "weight_decay": self.config.CLIENT_OPT_L2,
                },
                is_maintaine_opt_state=config.CLIENT_OPT_STRATEGY == "nothing",
                exp_id=experiment.id,
                is_scaffold=self.config.SCAFFOLD,
            )
        )