def aggregate(conf, master_model, fedavg_model, client_models, flatten_local_models):
    # perform the server momentum (either heavy-ball momentum or nesterov momentum).
    fl_aggregate = conf.fl_aggregate
    assert "server_momentum_factor" in fl_aggregate

    # start the server momentum acceleration.
    current_model_tb = TensorBuffer(list(fedavg_model.parameters()))
    previous_model_tb = TensorBuffer(list(master_model.parameters()))

    # get the update direction.
    update = previous_model_tb.buffer - current_model_tb.buffer

    # using server momentum for the update.
    if not hasattr(conf, "server_momentum_buffer"):
        conf.server_momentum_buffer = torch.zeros_like(update)
    conf.server_momentum_buffer.mul_(fl_aggregate["server_momentum_factor"]).add_(update)
    previous_model_tb.buffer.add_(-conf.server_momentum_buffer)

    # update the master_model (but will use the bn stats from the fedavg_model).
    master_model = fedavg_model
    _model_param = list(master_model.parameters())
    previous_model_tb.unpack(_model_param)

    # free the memory.
    torch.cuda.empty_cache()

    # a temp hack (only for debug reason).
    client_models = dict(
        (used_client_arch, master_model.cpu())
        for used_client_arch in conf.used_client_archs
    )
    return client_models
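
# Illustrative sketch (not part of the repo): the server-side heavy-ball update above,
# restated on plain flat tensors. The names below (prev_flat, fedavg_flat, momentum_buf)
# are hypothetical stand-ins for the flattened TensorBuffer contents.
import torch

def server_momentum_update(prev_flat, fedavg_flat, momentum_buf, momentum_factor=0.9):
    # pseudo-gradient seen by the server: old global model minus the new FedAvg model.
    update = prev_flat - fedavg_flat
    # heavy-ball accumulation: m_t = factor * m_{t-1} + update.
    momentum_buf.mul_(momentum_factor).add_(update)
    # new global model: x_t = x_{t-1} - m_t.
    return prev_flat - momentum_buf

# toy usage on random vectors standing in for flattened model parameters.
prev, fedavg = torch.randn(10), torch.randn(10)
buf = torch.zeros(10)
new_global = server_momentum_update(prev, fedavg, buf)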
def step(self, closure=None, **kargs):
    # Apply the gradients with the weight decay and momentum.
    with kargs["timer"]("grad.apply_grad", epoch=self.conf.epoch_):
        utils.apply_gradient(
            self.param_groups, self.state, apply_grad_to_model=True
        )

    with kargs["timer"]("grad.get_params", epoch=self.conf.epoch_):
        params, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=False
        )
        params_tb = TensorBuffer(params)

    with kargs["timer"]("grad.error_compensate", epoch=self.conf.epoch_):
        self.memory.buffer += params_tb.buffer

    with kargs["timer"]("grad.compress", epoch=self.conf.epoch_):
        sync_buffer = {"original_shapes": self.shapes, "params_tb": self.memory}
        local_compressed_params_tb = self.compressor.compress(sync_buffer)

    with kargs["timer"]("grad.update_memory", epoch=self.conf.epoch_):
        self.memory.buffer = self.memory.buffer - local_compressed_params_tb.buffer

    with kargs["timer"]("grad.sync", epoch=self.conf.epoch_):
        self.compressor.sync(sync_buffer)

    # update local model.
    with kargs["timer"]("grad.decompress", epoch=self.conf.epoch_):
        aggregated_info_tb = self.compressor.uncompress(
            sync_buffer, self.neighbors_info
        )
        params_tb.buffer += aggregated_info_tb.buffer
        params_tb.unpack(params)
    return sync_buffer["n_bits"]
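
# Illustrative sketch (not part of the repo): the error-compensation pattern used in the
# step above (fold params into a memory, transmit a compressed version, keep the residual),
# with a simple top-k sparsifier standing in for the repo's compressor. All names are
# hypothetical.
import torch

def topk_compress(x, k):
    # keep the k largest-magnitude entries, zero out the rest.
    _, indices = torch.topk(x.abs(), k)
    compressed = torch.zeros_like(x)
    compressed[indices] = x[indices]
    return compressed

def error_feedback_step(x, memory, k):
    memory += x                            # self.memory.buffer += params_tb.buffer
    compressed = topk_compress(memory, k)  # compress what is about to be sent.
    memory -= compressed                   # keep the untransmitted residual for next time.
    return compressed, memory

x = torch.randn(100)
mem = torch.zeros_like(x)
sent, mem = error_feedback_step(x, mem, k=10)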
def step(self, closure=None, **kargs):
    if self.conf.is_centralized:
        with kargs["timer"]("sync/get_data", epoch=self.conf.epoch_):
            # Get data.
            grads, _ = comm.get_data(
                self.param_groups, self.param_names, is_get_grad=True
            )
            flatten_grads = TensorBuffer(grads)

        with kargs["timer"]("sync/sync", epoch=self.conf.epoch_):
            # Aggregate the gradients.
            flatten_grads.buffer = self.world_aggregator._agg(
                flatten_grads.buffer, op="avg", distributed=self.conf.distributed
            )

        with kargs["timer"]("sync/unflatten_grad", epoch=self.conf.epoch_):
            # unflatten grads.
            flatten_grads.unpack(grads)

        with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
            utils.apply_gradient(
                self.param_groups, self.state, apply_grad_to_model=True
            )

        # Get n_bits to transmit.
        n_bits = get_n_bits(flatten_grads.buffer)
    else:
        with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
            utils.apply_gradient(
                self.param_groups, self.state, apply_grad_to_model=True
            )

        with kargs["timer"]("sync/get_data", epoch=self.conf.epoch_):
            # first get and flatten all params.
            params, _ = comm.get_data(
                self.param_groups, self.param_names, is_get_grad=False
            )
            flatten_params = TensorBuffer(params)

        with kargs["timer"]("sync/sync", epoch=self.conf.epoch_):
            # prepare the sync.
            if self.conf.comm_device == "cpu":
                flatten_params.buffer.cpu().detach_()

            # then sync.
            flatten_params.buffer = self.decentralized_aggregator._agg(
                flatten_params.buffer, op="weighted"
            )

        with kargs["timer"]("sync/update_model", epoch=self.conf.epoch_):
            # finally unflatten.
            flatten_params.unpack(params)

        # Get n_bits to transmit.
        n_bits = get_n_bits(flatten_params.buffer)
    return n_bits
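
# Illustrative sketch (not part of the repo): the flatten -> average -> unflatten pattern of
# the centralized branch, written against plain torch.distributed. It assumes a process group
# has already been initialized and replaces the repo's aggregator objects with an explicit
# all-reduce; average_gradients and its arguments are hypothetical.
import torch
import torch.distributed as dist

def average_gradients(params, world_size):
    # flatten all gradients into one contiguous buffer.
    grads = [p.grad.data for p in params if p.grad is not None]
    buffer = torch.cat([g.view(-1) for g in grads])

    # sum across workers, then divide to obtain the average.
    dist.all_reduce(buffer, op=dist.ReduceOp.SUM)
    buffer.div_(world_size)

    # unflatten: copy the averaged values back into the gradient tensors.
    offset = 0
    for g in grads:
        g.copy_(buffer[offset:offset + g.numel()].view_as(g))
        offset += g.numel()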
def step(self, closure=None, **kargs):
    # Apply the gradients with the weight decay and momentum.
    with kargs["timer"]("grad.apply_grad", epoch=self.conf.epoch_):
        utils.apply_gradient(
            self.param_groups, self.state, apply_grad_to_model=False
        )

    with kargs["timer"]("grad.get_grads", epoch=self.conf.epoch_):
        params, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=False
        )
        flatten_params = TensorBuffer(params)
        grads, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=True
        )
        flatten_grads = TensorBuffer(grads)

    # Get weighted hat params and apply the local gradient.
    with kargs["timer"]("grad.apply_local_gradient", epoch=self.conf.epoch_):
        flatten_half_params = deepcopy(flatten_params)
        flatten_half_params.buffer = (
            sum(
                [
                    _hat_params.buffer * self.neighbors_info[_rank]
                    for _rank, _hat_params in self.neighbor_hat_params.items()
                ]
            )
            - self.param_groups[0]["lr"] * flatten_grads.buffer
        )

    # compress the model difference and sync.
    with kargs["timer"]("grad.compress", epoch=self.conf.epoch_):
        sync_buffer = {
            "original_shapes": self.shapes,
            "flatten_half_params": flatten_half_params,
            "flatten_params": flatten_params,
        }
        self.compressor.compress(sync_buffer)

    with kargs["timer"]("grad.sync", epoch=self.conf.epoch_):
        self.compressor.sync(sync_buffer)

    # finally unflatten and update local model.
    with kargs["timer"]("grad.unflatten_to_update", epoch=self.conf.epoch_):
        self.compressor.uncompress(sync_buffer, self.neighbor_hat_params)
        flatten_params.buffer = self.neighbor_hat_params[self.rank].buffer.clone()
        flatten_params.unpack(params)
    return sync_buffer["n_bits"]
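
# Illustrative sketch (not part of the repo): the "weighted hat params minus a local gradient
# step" combination computed above, condensed to flat tensors. neighbor_hat, mixing_weights,
# and grad_flat are hypothetical stand-ins for neighbor_hat_params, neighbors_info, and the
# flattened gradients.
import torch

def half_step(neighbor_hat, mixing_weights, grad_flat, lr):
    # gossip-average the locally kept copies of the neighbors' (compressed) models ...
    mixed = sum(mixing_weights[rank] * hat for rank, hat in neighbor_hat.items())
    # ... then take one local SGD step from that average.
    return mixed - lr * grad_flat

# toy usage: two neighbors plus ourselves, uniform mixing weights.
dim = 8
neighbor_hat = {0: torch.randn(dim), 1: torch.randn(dim), 2: torch.randn(dim)}
weights = {0: 1 / 3, 1: 1 / 3, 2: 1 / 3}
half = half_step(neighbor_hat, weights, torch.randn(dim), lr=0.1)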
def step(self, closure=None, **kargs):
    with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_):
        utils.apply_gradient(
            self.param_groups, self.state, apply_grad_to_model=False
        )

    with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_):
        # Get data.
        grads, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=True
        )
        grads_tb = TensorBuffer(grads)

    with kargs["timer"]("sync.use_memory", epoch=self.conf.epoch_):
        # use memory.
        grads_tb.buffer.add_(self.memory_tb.buffer)

    with kargs["timer"]("sync.compress", epoch=self.conf.epoch_):
        # compress.
        sync_buffer = self.compressor.compress(grads_tb)

    with kargs["timer"]("sync.sync", epoch=self.conf.epoch_):
        self.compressor.sync(sync_buffer)

    with kargs["timer"]("sync.update_memory", epoch=self.conf.epoch_):
        # update memory.
        self.memory_tb.buffer = (
            grads_tb.buffer - sync_buffer["synced_grads_tb"].buffer
        )

    with kargs["timer"]("sync.decompress", epoch=self.conf.epoch_):
        sync_grads_tb = self.compressor.decompress(sync_buffer)

    with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_):
        # apply the gradient, but only the gradient.
        params, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=False
        )
        params_tb = TensorBuffer(params)

        # apply the gradient.
        params_tb.buffer.add_(-self.param_groups[0]["lr"] * sync_grads_tb.buffer)

        # unpack.
        params_tb.unpack(params)
    return sync_buffer["n_bits"]
def aggregate(conf, master_model, fedavg_model, client_models, flatten_local_models):
    # perform the server Adam.
    # Following the setup in the paper, we use a momentum of 0.9,
    # a numerical stability constant epsilon of 0.01,
    # and beta_2 set to 0.99.
    # The suggested server_lr in the original paper is 0.1.
    fl_aggregate = conf.fl_aggregate
    assert "server_lr" in fl_aggregate
    beta_2 = fl_aggregate["beta_2"] if "beta_2" in fl_aggregate else 0.99

    # start the server momentum acceleration.
    current_model_tb = TensorBuffer(list(fedavg_model.parameters()))
    previous_model_tb = TensorBuffer(list(master_model.parameters()))

    # get the update direction.
    update = previous_model_tb.buffer - current_model_tb.buffer

    # using server momentum for the update.
    if not hasattr(conf, "second_server_momentum_buffer"):
        conf.second_server_momentum_buffer = torch.zeros_like(update)
    conf.second_server_momentum_buffer.mul_(beta_2).add_((1 - beta_2) * (update ** 2))
    previous_model_tb.buffer.add_(
        -fl_aggregate["server_lr"]
        * update
        / (torch.sqrt(conf.second_server_momentum_buffer) + 0.01)
    )

    # update the master_model (but will use the bn stats from the fedavg_model).
    master_model = fedavg_model
    _model_param = list(master_model.parameters())
    previous_model_tb.unpack(_model_param)

    # free the memory.
    torch.cuda.empty_cache()

    # a temp hack (only for debug reason).
    client_models = dict(
        (used_client_arch, master_model.cpu())
        for used_client_arch in conf.used_client_archs
    )
    return client_models
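
# Illustrative sketch (not part of the repo): the adaptive server update above, restated on
# flat tensors. As in the function above, only the second moment is tracked and no bias
# correction is applied; all names and default values are hypothetical.
import torch

def server_adam_like_update(prev_flat, fedavg_flat, second_moment,
                            server_lr=0.1, beta_2=0.99, eps=0.01):
    # pseudo-gradient: old global model minus the new FedAvg model.
    update = prev_flat - fedavg_flat
    # exponential moving average of the squared update (second moment).
    second_moment.mul_(beta_2).add_((1 - beta_2) * update ** 2)
    # adaptive step scaled by the root of the second moment.
    return prev_flat - server_lr * update / (second_moment.sqrt() + eps)

prev, fedavg = torch.randn(10), torch.randn(10)
v = torch.zeros(10)
new_global = server_adam_like_update(prev, fedavg, v)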
def step(self, closure=None, **kargs):
    # do the local update steps.
    with kargs["timer"]("sync/get_data", epoch=self.conf.epoch_):
        # get params.
        params, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=False
        )
        params_tb = TensorBuffer(params)

    with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
        # prepare the gradient (sign).
        utils.apply_gradient(
            self.param_groups, self.state, apply_grad_to_model=False
        )
        # get grads.
        grads, _ = comm.get_data(
            self.param_groups, self.param_names, is_get_grad=True
        )
        grads_tb = TensorBuffer(grads)

    # enter the global sync if it satisfies the condition.
    # get the params difference w.r.t. previous synced model.
    with kargs["timer"]("sync/compress", epoch=self.conf.epoch_):
        sync_buffer = self.compressor.compress(grads_tb)

    # sync and decompress.
    with kargs["timer"]("sync/sync_and_decompress", epoch=self.conf.epoch_):
        self.compressor.sync(sync_buffer)
        synced_updates_tb = self.compressor.decompress(sync_buffer)

    # unpack the synced info and update the consensus params.
    with kargs["timer"]("sync/apply_grad", epoch=self.conf.epoch_):
        params_tb.buffer -= self.param_groups[0]["lr"] * synced_updates_tb.buffer
        params_tb.unpack(params)
    return sync_buffer["n_bits"]
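
# Illustrative sketch (not part of the repo): the comment "prepare the gradient (sign)" above
# suggests a sign-style compressor; as a standalone stand-in (not the repo's compressor), the
# gradient is encoded by its sign plus a single per-tensor scale before being applied.
import torch

def sign_compress(grad_flat):
    # one bit per coordinate plus one scale (mean magnitude).
    return torch.sign(grad_flat), grad_flat.abs().mean()

def sign_decompress(signs, scale):
    return scale * signs

grad = torch.randn(100)
signs, scale = sign_compress(grad)
update = sign_decompress(signs, scale)
params = torch.randn(100)
params -= 0.1 * update  # mirrors: params_tb.buffer -= lr * synced_updates_tb.buffer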
class Worker(object):
    def __init__(self, conf):
        self.conf = conf

        # some initializations.
        self.rank = conf.graph.rank
        conf.graph.worker_id = conf.graph.rank
        self.device = torch.device("cuda" if self.conf.graph.on_cuda else "cpu")

        # define the timer for different operations.
        # if we choose the `train_fast` mode, then we will not track the time.
        self.timer = Timer(
            verbosity_level=1 if conf.track_time and not conf.train_fast else 0,
            log_fn=conf.logger.log_metric,
        )

        # create dataset (as well as the potential data_partitioner) for training.
        dist.barrier()
        self.dataset = create_dataset.define_dataset(conf, data=conf.data)
        _, self.data_partitioner = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            localdata_id=0,  # random id here.
            is_train=True,
            data_partitioner=None,
        )
        conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} initialized the local training data with Master."
        )

        # define the criterion.
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

        # define the model compression operators.
        if conf.local_model_compression is not None:
            if conf.local_model_compression == "quantization":
                self.model_compression_fn = compressor.ModelQuantization(conf)

        conf.logger.log(
            f"Worker-{conf.graph.worker_id} initialized dataset/criterion.\n"
        )

    def run(self):
        while True:
            self._listen_to_master()

            # check if we need to terminate the training or not.
            if self._terminate_by_early_stopping():
                return

            self._recv_model_from_master()
            self._train()
            self._send_model_to_master()

            # check if we need to terminate the training or not.
            if self._terminate_by_complete_training():
                return

    def _listen_to_master(self):
        # listen to master, related to the function `_activate_selected_clients` in `master.py`.
        msg = torch.zeros((3, self.conf.n_participated))
        dist.broadcast(tensor=msg, src=0)
        self.conf.graph.client_id, self.conf.graph.comm_round, self.n_local_epochs = (
            msg[:, self.conf.graph.rank - 1].to(int).cpu().numpy().tolist()
        )

        # once we receive the signal, we init for the local training.
        self.arch, self.model = create_model.define_model(
            self.conf, to_consistent_model=False, client_id=self.conf.graph.client_id
        )
        self.model_state_dict = self.model.state_dict()
        self.model_tb = TensorBuffer(list(self.model_state_dict.values()))
        self.metrics = create_metrics.Metrics(self.model, task="classification")
        dist.barrier()

    def _recv_model_from_master(self):
        # related to the function `_send_model_to_selected_clients` in `master.py`.
        old_buffer = copy.deepcopy(self.model_tb.buffer)
        dist.recv(self.model_tb.buffer, src=0)
        new_buffer = copy.deepcopy(self.model_tb.buffer)
        self.model_tb.unpack(self.model_state_dict.values())
        self.model.load_state_dict(self.model_state_dict)
        random_reinit.random_reinit_model(self.conf, self.model)
        self.init_model = self._turn_off_grad(copy.deepcopy(self.model).to(self.device))
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) received the model ({self.arch}) from Master. The model status {'is updated' if old_buffer.norm() != new_buffer.norm() else 'is not updated'}."
        )
        dist.barrier()

    def _train(self):
        self.model.train()

        # init the model and dataloader.
        if self.conf.graph.on_cuda:
            self.model = self.model.cuda()
        self.train_loader, _ = create_dataset.define_data_loader(
            self.conf,
            dataset=self.dataset["train"],
            # localdata_id starts from 0 to the # of clients - 1.
            # client_id starts from 1 to the # of clients.
            localdata_id=self.conf.graph.client_id - 1,
            is_train=True,
            data_partitioner=self.data_partitioner,
        )

        # define optimizer, scheduler and runtime tracker.
        self.optimizer = create_optimizer.define_optimizer(
            self.conf, model=self.model, optimizer_name=self.conf.optimizer
        )
        self.scheduler = create_scheduler.Scheduler(self.conf, optimizer=self.optimizer)
        self.tracker = RuntimeTracker(metrics_to_track=self.metrics.metric_names)
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) enters the local training phase (current communication rounds={self.conf.graph.comm_round})."
        )

        # efficient local training.
        if hasattr(self, "model_compression_fn"):
            self.model_compression_fn.compress_model(
                param_groups=self.optimizer.param_groups
            )

        # entering local updates and will finish only after reaching the expected local_n_epochs.
        while True:
            for _input, _target in self.train_loader:
                # load data.
                with self.timer("load_data", epoch=self.scheduler.epoch_):
                    data_batch = create_dataset.load_data_batch(
                        self.conf, _input, _target, is_training=True
                    )

                # inference and get current performance.
                with self.timer("forward_pass", epoch=self.scheduler.epoch_):
                    self.optimizer.zero_grad()
                    loss, output = self._inference(data_batch)

                    # in case we need self distillation to penalize the local training
                    # (avoid catastrophic forgetting).
                    loss = self._local_training_with_self_distillation(
                        loss, output, data_batch
                    )

                with self.timer("backward_pass", epoch=self.scheduler.epoch_):
                    loss.backward()
                    self._add_grad_from_prox_regularized_loss()
                    self.optimizer.step()
                    self.scheduler.step()

                # efficient local training.
                with self.timer("compress_model", epoch=self.scheduler.epoch_):
                    if hasattr(self, "model_compression_fn"):
                        self.model_compression_fn.compress_model(
                            param_groups=self.optimizer.param_groups
                        )

                # display the logging info.
                display_training_stat(self.conf, self.scheduler, self.tracker)

                # display tracking time.
                if (
                    self.conf.display_tracked_time
                    and self.scheduler.local_index % self.conf.summary_freq == 0
                ):
                    self.conf.logger.log(self.timer.summary())

                # check divergence.
                if self.tracker.stat["loss"].avg > 1e3 or np.isnan(
                    self.tracker.stat["loss"].avg
                ):
                    self.conf.logger.log(
                        f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) diverges!!!!! Early stop it."
                    )
                    self._terminate_comm_round()
                    return

                # check stopping condition.
                if self._is_finished_one_comm_round():
                    self._terminate_comm_round()
                    return

            # refresh the logging cache at the end of each epoch.
            self.tracker.reset()
            if self.conf.logger.meet_cache_limit():
                self.conf.logger.save_json()

    def _inference(self, data_batch):
        """Inference on the given model and get loss and accuracy."""
        # do the forward pass and get the output.
        output = self.model(data_batch["input"])

        # evaluate the output and get the loss, performance.
        if self.conf.use_mixup:
            loss = mixup.mixup_criterion(
                self.criterion,
                output,
                data_batch["target_a"],
                data_batch["target_b"],
                data_batch["mixup_lambda"],
            )
            performance_a = self.metrics.evaluate(loss, output, data_batch["target_a"])
            performance_b = self.metrics.evaluate(loss, output, data_batch["target_b"])
            performance = [
                data_batch["mixup_lambda"] * _a + (1 - data_batch["mixup_lambda"]) * _b
                for _a, _b in zip(performance_a, performance_b)
            ]
        else:
            loss = self.criterion(output, data_batch["target"])
            performance = self.metrics.evaluate(loss, output, data_batch["target"])

        # update tracker.
        if self.tracker is not None:
            self.tracker.update_metrics(
                [loss.item()] + performance, n_samples=data_batch["input"].size(0)
            )
        return loss, output

    def _add_grad_from_prox_regularized_loss(self):
        assert self.conf.local_prox_term >= 0
        if self.conf.local_prox_term != 0:
            assert self.conf.weight_decay == 0
            assert self.conf.optimizer == "sgd"
            assert self.conf.momentum_factor == 0

            for _param, _init_param in zip(
                self.model.parameters(), self.init_model.parameters()
            ):
                if _param.grad is not None:
                    _param.grad.data.add_(
                        (_param.data - _init_param.data) * self.conf.local_prox_term
                    )

    def _local_training_with_self_distillation(self, loss, output, data_batch):
        if self.conf.self_distillation > 0:
            loss = loss * (
                1 - self.conf.self_distillation
            ) + self.conf.self_distillation * self._divergence(
                student_logits=output / self.conf.self_distillation_temperature,
                teacher_logits=self.init_model(data_batch["input"])
                / self.conf.self_distillation_temperature,
            )
        return loss

    def _divergence(self, student_logits, teacher_logits):
        divergence = F.kl_div(
            F.log_softmax(student_logits, dim=1),
            F.softmax(teacher_logits, dim=1),
            reduction="batchmean",
        )  # forward KL
        return divergence

    def _turn_off_grad(self, model):
        for param in model.parameters():
            param.requires_grad = False
        return model

    def _send_model_to_master(self):
        dist.barrier()
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) sending the model ({self.arch}) back to Master."
        )
        flatten_model = TensorBuffer(list(self.model.state_dict().values()))
        dist.send(tensor=flatten_model.buffer, dst=0)
        dist.barrier()

    def _terminate_comm_round(self):
        self.model = self.model.cpu()
        del self.init_model
        self.scheduler.clean()
        self.conf.logger.save_json()
        self.conf.logger.log(
            f"Worker-{self.conf.graph.worker_id} (client-{self.conf.graph.client_id}) finished one round of federated learning: (comm_round={self.conf.graph.comm_round})."
        )

    def _terminate_by_early_stopping(self):
        if self.conf.graph.comm_round == -1:
            dist.barrier()
            self.conf.logger.log(
                f"Worker-{self.conf.graph.worker_id} finished the federated learning by early-stopping."
            )
            return True
        else:
            return False

    def _terminate_by_complete_training(self):
        if self.conf.graph.comm_round == self.conf.n_comm_rounds:
            dist.barrier()
            self.conf.logger.log(
                f"Worker-{self.conf.graph.worker_id} finished the federated learning: (total comm_rounds={self.conf.graph.comm_round})."
            )
            return True
        else:
            return False

    def _is_finished_one_comm_round(self):
        return True if self.conf.epoch_ >= self.conf.local_n_epochs else False
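
# Illustrative sketch (not part of the repo): the loss combination performed by
# _local_training_with_self_distillation above, as a standalone function. The default weight
# and temperature values are hypothetical; only the form of the combination follows the code above.
import torch
import torch.nn.functional as F

def self_distillation_loss(task_loss, student_logits, teacher_logits,
                           weight=0.3, temperature=3.0):
    # forward KL between temperature-scaled student and teacher distributions.
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    )
    # convex combination of the task loss and the distillation term.
    return (1 - weight) * task_loss + weight * kl

# toy usage with random logits for a batch of 4 samples over 10 classes.
student, teacher = torch.randn(4, 10), torch.randn(4, 10)
ce = F.cross_entropy(student, torch.randint(0, 10, (4,)))
loss = self_distillation_loss(ce, student, teacher)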