def client_update(self, global_model, global_init_model, round_index):
    """Run one communication round on this client and return its model.

    Steps, in order:
      1. Replace ``self.model`` with a ``copy_model`` copy of
         ``global_model``, passing this client's current named buffers.
      2. Evaluate that copy on ``self.test_loader``. If the first
         ``'Accuracy'`` entry exceeds ``args.acc_thresh`` and the current
         prune rate is below ``args.prune_percent``, prune further and
         rebuild the model from ``global_init_model`` using the
         post-pruning buffers.
      3. Train for ``args.client_epoch`` epochs, keeping the last loss and
         accuracy entry reported for each epoch.
      4. Log the model's named buffers and store the round's statistics.

    Args:
        global_model: model handed to the client for this round.
        global_init_model: model the client is rebuilt from after pruning
            (presumably the initial weights -- confirm with the caller).
        round_index: index of the current communication round; used in the
            mask log path and as the start of the slices written into the
            statistic arrays.

    Returns:
        A ``copy_model`` copy of the trained client model.
    """
    self.elapsed_comm_rounds += 1
    print(f'***** Client #{self.client_id} *****', flush=True)

    args = self.args

    # Start from the global model, handing over this client's own buffers.
    own_buffers = dict(self.model.named_buffers())
    self.model = copy_model(global_model, args.dataset, args.arch,
                            own_buffers)

    num_pruned, num_params = get_prune_summary(self.model)
    cur_prune_rate = num_pruned / num_params

    score = evaluate(self.model,
                     self.test_loader,
                     verbose=args.test_verbosity)

    if (score['Accuracy'][0] > args.acc_thresh
            and cur_prune_rate < args.prune_percent):
        # The extra 0.001 is only there to make sure the target
        # prune_percent is actually cleared; it may not be needed.
        remaining = 0.001 + args.prune_percent - cur_prune_rate
        prune_fraction = min(args.prune_step, remaining)
        prune_fixed_amount(self.model,
                           prune_fraction,
                           verbose=args.prune_verbosity,
                           glob=True)
        # Rebuild from the init model, carrying over the buffers that the
        # pruning step just updated.
        pruned_buffers = dict(self.model.named_buffers())
        self.model = copy_model(global_init_model, args.dataset, args.arch,
                                pruned_buffers)

    losses = []
    accuracies = []
    for epoch in range(args.client_epoch):
        epoch_score = train(round_index,
                            self.client_id,
                            epoch,
                            self.model,
                            self.train_loader,
                            lr=args.lr,
                            verbose=args.train_verbosity)
        # Only the final entry of each metric is kept per epoch.
        losses.append(epoch_score['Loss'][-1].data.item())
        accuracies.append(epoch_score['Accuracy'][-1])

    # Persist this client's buffers for the round.
    mask_log_path = f'{self.args.log_folder}/round{round_index}/c{self.client_id}.mask'
    log_obj(mask_log_path, dict(self.model.named_buffers()))

    num_pruned, num_params = get_prune_summary(self.model)
    cur_prune_rate = num_pruned / num_params
    prune_step = math.floor(num_params * args.prune_step)
    print(
        f"num_pruned {num_pruned}, num_params {num_params}, cur_prune_rate {cur_prune_rate}, prune_step: {prune_step}"
    )

    # Slice assignment: every entry from round_index onward is overwritten.
    self.losses[round_index:] = np.array(losses)
    self.accuracies[round_index:] = np.array(accuracies)
    self.prune_rates[round_index:] = cur_prune_rate

    return copy_model(self.model, args.dataset, args.arch)
def client_update(self, global_model, global_init_model, comm_round):
    """Train a copy of the global model locally and record round statistics.

    ``self.model`` is replaced by a ``copy_model`` copy of ``global_model``
    and trained for ``args.client_epoch`` epochs. The last loss and accuracy
    entry of each epoch, plus the model's prune rate after training, are
    written into the client's statistic arrays.

    Args:
        global_model: model handed to the client for this round.
        global_init_model: accepted but not used by this method.
        comm_round: index of the current communication round; also the
            start of the slices written into the statistic arrays.

    Returns:
        None. The trained model is left in ``self.model``.
    """
    self.elapsed_comm_rounds += 1
    print(f'***** Client #{self.client_id} *****', flush=True)

    args = self.args
    self.model = copy_model(global_model, args.dataset, args.arch)

    losses, accuracies = [], []
    for local_epoch in range(args.client_epoch):
        epoch_score = train(comm_round,
                            self.client_id,
                            local_epoch,
                            self.model,
                            self.train_loader,
                            lr=args.lr,
                            verbose=args.train_verbosity)
        # Only the final entry of each metric is kept per epoch.
        losses.append(epoch_score['Loss'][-1].data.item())
        accuracies.append(epoch_score['Accuracy'][-1])

    num_pruned, num_params = get_prune_summary(self.model)
    cur_prune_rate = num_pruned / num_params
    print(
        f"num_pruned {num_pruned}, num_params {num_params}, cur_prune_rate {cur_prune_rate}"
    )

    # Slice assignment: every entry from comm_round onward is overwritten.
    self.losses[comm_round:] = np.array(losses)
    self.accuracies[comm_round:] = np.array(accuracies)
    self.prune_rates[comm_round:] = cur_prune_rate
def server_update(self):
    """Run every communication round of the server loop.

    For each of ``self.comm_rounds`` rounds:
      * sample clients without replacement (always at least one),
      * let each sampled client run ``client_update`` on the global model,
      * average the models of *all* clients with ``fed_avg`` and copy the
        previous global buffers into the averaged model,
      * record per-client accuracies (non-participants inherit their value
        from the previous round; in round 0 they are left untouched),
      * prune the global model and rebuild it from ``global_init_model``
        when the mean client accuracy exceeds ``args.acc_thresh`` and the
        prune rate is still below ``args.prune_percent``.

    Returns:
        None. State is updated in place on ``self``.
    """
    self.elapsed_comm_rounds += 1
    self.global_models.train()

    for comm_round in range(self.comm_rounds):
        num_selected = max(int(self.frac * self.num_clients), 1)
        selected_clients = np.random.choice(self.num_clients,
                                            num_selected,
                                            replace=False)

        print('-------------------------------------', flush=True)
        print(
            f'Communication Round #{comm_round} Clients={selected_clients}',
            flush=True)
        print('-------------------------------------', flush=True)

        participants = [self.clients[idx] for idx in selected_clients]
        for participant in participants:
            participant.client_update(self.global_models,
                                      self.global_init_model,
                                      comm_round)

        # Note: the average runs over every client, not just participants.
        all_models = [client.model for client in self.clients]
        averaged = fed_avg(all_models,
                           self.args.dataset,
                           self.args.arch,
                           self.client_data_num)

        # fed_avg clobbers the mask, so the global model's buffers are
        # copied back into the averaged model before it becomes global.
        previous_buffers = dict(self.global_models.named_buffers())
        for buffer_name, buffer in averaged.named_buffers():
            buffer.data.copy_(previous_buffers[buffer_name])
        self.global_models = averaged

        # server accuracies are not useful for Genesis
        self.accuracies[comm_round] = 0

        # gather client accuracies
        for k, client in enumerate(self.clients):
            if k in selected_clients:
                self.client_accuracies[k][comm_round] = client.evaluate()
            elif comm_round > 0:
                # Non-participants carry their previous accuracy forward.
                self.client_accuracies[k][comm_round] = \
                    self.client_accuracies[k][comm_round - 1]

        print(
            f"End of round accuracy: all={self.client_accuracies[:, comm_round].mean()}, "
            f"participating={self.client_accuracies[selected_clients, comm_round].mean()}"
        )

        # prune global model if appropriate
        num_pruned, num_params = get_prune_summary(self.global_models)
        cur_prune_rate = num_pruned / num_params
        should_prune = (
            self.client_accuracies[:, comm_round].mean() > self.args.acc_thresh
            and cur_prune_rate < self.args.prune_percent)
        if not should_prune:
            continue

        prune_fixed_amount(self.global_models,
                           self.args.prune_step,
                           verbose=self.args.prune_verbosity)
        # Rebuild from the init model, keeping the freshly pruned buffers.
        pruned_buffers = dict(self.global_models.named_buffers())
        self.global_models = copy_model(self.global_init_model,
                                        self.args.dataset,
                                        self.args.arch,
                                        pruned_buffers)
def aggr(self, models, clients, *args, **kwargs):
    """Average client models and re-register the resulting zeros as pruned.

    Each client is weighted by its share of the total ``num_data``. After
    averaging, every (module, parameter-name) pair returned by
    ``get_prune_params`` is passed to ``prune.l1_unstructured`` with an
    ``amount`` equal to the number of entries that are exactly zero in the
    averaged parameter.

    Args:
        models: client models to average; forwarded to ``fed_avg``.
        clients: clients that produced ``models``; only ``num_data`` is read.
        *args, **kwargs: accepted for interface compatibility and ignored.

    Returns:
        The averaged model returned by ``fed_avg``, with pruning applied.

    Raises:
        ValueError: if the clients' ``num_data`` values do not sum to a
            positive number (e.g. no clients), since normalising by that
            sum would silently yield NaN/inf averaging weights.
    """
    print("----------Averaging Models--------")

    weights_per_client = np.array([client.num_data for client in clients],
                                  dtype=np.float32)
    total_data = np.sum(weights_per_client)
    if not total_data > 0:
        raise ValueError(
            "cannot aggregate: clients' num_data must sum to a positive value"
        )
    weights_per_client /= total_data

    aggr_model = fed_avg(models=models,
                         weights=weights_per_client,
                         device=self.args.device)

    pruned_summary, _, _ = get_prune_summary(aggr_model, name='weight')
    print(tabulate(pruned_summary, headers='keys', tablefmt='github'))

    for module, param_name in get_prune_params(aggr_model):
        # Keep the count as an exact integer. Going through float32 (as
        # `.float()` does) loses precision above 2**24 entries and can even
        # round the count up past the tensor size.
        num_zeros = int(
            torch.eq(getattr(module, param_name).data, 0.00).sum().item())
        prune.l1_unstructured(module, param_name, num_zeros)

    return aggr_model
def client_update_method1(client_self, global_model, global_init_model):
    """Train a copy of the global model on a client, then optionally prune.

    After sanity-checking the client object, ``client_self.model`` is
    replaced by a ``copy_model`` copy of ``global_model`` and trained for
    ``args.client_epoch`` epochs. The trained model is evaluated once; if
    the first ``'Accuracy'`` entry exceeds ``args.acc_thresh`` and the prune
    rate measured *before* training is below ``args.prune_percent``, the
    model is pruned by ``floor(num_params * args.prune_step)``.

    Args:
        client_self: the client object (plays the role of ``self``).
        global_model: model handed to the client for this round.
        global_init_model: accepted but not used by this function.

    Returns:
        A ``copy_model`` copy of the client's model after training/pruning.

    Raises:
        AssertionError: if the client object fails any of the checks below.
    """
    print(f'***** Client #{client_self.client_id} *****', flush=True)

    # Checking if the client object has been properly initialized
    assert isinstance(client_self.model,
                      nn.Module), "A model must be a PyTorch module"
    assert 0 <= client_self.args.prune_percent <= 1, "The prune percentage must be between 0 and 1"
    assert client_self.args.client_epoch, '"args" must contain a "client_epoch" field'
    assert client_self.test_loader, "test_loader field does not exist. Check if the client is initialized correctly"
    assert client_self.train_loader, "train_loader field does not exist. Check if the client is initialized correctly"
    assert isinstance(
        client_self.train_loader,
        torch.utils.data.DataLoader), "train_loader must be a DataLoader type"
    assert isinstance(
        client_self.test_loader,
        torch.utils.data.DataLoader), "test_loader must be a DataLoader type"

    cfg = client_self.args
    client_self.model = copy_model(global_model, cfg.dataset, cfg.arch)

    # Prune statistics are taken once, before any local training.
    num_pruned, num_params = get_prune_summary(client_self.model)
    cur_prune_rate = num_pruned / num_params
    prune_step = math.floor(num_params * cfg.prune_step)

    for epoch in range(cfg.client_epoch):
        print(f'Epoch {epoch + 1}')
        train(client_self.model,
              client_self.train_loader,
              lr=cfg.lr,
              verbose=cfg.train_verbosity)

    score = evaluate(client_self.model,
                     client_self.test_loader,
                     verbose=cfg.test_verbosity)

    accurate_enough = score['Accuracy'][0] > cfg.acc_thresh
    if accurate_enough and cur_prune_rate < cfg.prune_percent:
        prune_fixed_amount(client_self.model,
                           prune_step,
                           verbose=cfg.prune_verbosity)

    return copy_model(client_self.model, cfg.dataset, cfg.arch)
def update(self, *args, **kwargs):
    """
        Interface to server and clients

        Runs one communication round:
          1. snapshot the current model into ``self.prev_model``;
          2. if server pruning is enabled, this round falls on the pruning
             frequency and the model is below the server prune threshold,
             call ``self.prune()`` followed by ``self.reinit()``;
          3. sample clients without replacement, upload the model to them,
             run ``client.update()`` on each and download their results;
          4. aggregate the downloaded models with ``self.aggr`` and make
             the result the new ``self.model``;
          5. log metrics to wandb and save a checkpoint of the state dict.

        ``*args`` / ``**kwargs`` are accepted for interface compatibility
        and ignored.
    """
    # Local import: only needed to make sure the checkpoint folder exists.
    import os

    self.elapsed_comm_rounds += 1
    self.prev_model = copy_model(self.model, self.args.device)

    print('-----------------------------', flush=True)
    print(f'| Communication Round: {self.elapsed_comm_rounds}  | ', flush=True)
    print('-----------------------------', flush=True)

    _, num_pruned, num_total = get_prune_summary(self.model)
    prune_percent = num_pruned / num_total

    # global_model pruned at fixed freq
    # with a fixed pruning step
    # (`== True` is kept on purpose: a non-bool flag must not enable pruning)
    if (self.args.server_prune == True and
            (self.elapsed_comm_rounds % self.args.server_prune_freq) == 0) and \
            (prune_percent < self.args.server_prune_threshold):
        # prune the model using super_mask
        self.prune()
        # reinitialize model with std.dev of init_model
        self.reinit()

    # Always pick at least one client when any exist. A small fraction
    # would otherwise truncate to zero, leaving nothing to average and a
    # NaN mean accuracy.
    num_selected = max(
        int(self.args.frac_clients_per_round * self.num_clients),
        min(1, self.num_clients),
    )
    client_idxs = np.random.choice(
        self.num_clients,
        num_selected,
        replace=False,
    )
    clients = [self.clients[i] for i in client_idxs]

    # upload model to selected clients
    self.upload(clients)

    # call training loop on all clients
    for client in clients:
        client.update()

    # download models from selected clients
    models, accs = self.download(clients)

    avg_accuracy = np.mean(accs, axis=0, dtype=np.float32)
    print('-----------------------------', flush=True)
    print(f'| Average Accuracy: {avg_accuracy}  | ', flush=True)
    print('-----------------------------', flush=True)

    # compute average-model and (prune it by 0.00 )
    aggr_model = self.aggr(models, clients)

    # copy aggregated-model's params to self.model (keep buffer same)
    self.model = aggr_model

    _, num_pruned, num_total = get_prune_summary(self.model)
    prune_percent = num_pruned / num_total

    wandb.log({
        "client_avg_acc": avg_accuracy,
        "comm_round": self.elapsed_comm_rounds,
        "global_prune_percent": prune_percent
    })

    print('Saving global model')
    # Create the folder first so a missing directory cannot throw away the
    # round that was just trained.
    os.makedirs("./checkpoints", exist_ok=True)
    torch.save(
        self.model.state_dict(),
        f"./checkpoints/server_model_{self.elapsed_comm_rounds}.pt")
def update(self) -> None:
    """
        Interface to Server

        Adapts the downloaded global model, trains it locally, logs the
        results and saves the model.

        Pruning decision (``prune_rate`` is the global model's zero
        fraction as reported by ``get_prune_summary``):

        * While ``cur_prune_rate`` is below ``args.prune_threshold``:
            - accuracy above ``eita``: raise ``cur_prune_rate`` by
              ``args.prune_step`` (capped at the threshold) and reset
              ``eita`` to ``eita_hat``;
            - otherwise: keep ``cur_prune_rate`` and multiply ``eita`` by
              ``alpha``.
        * The model is pruned to ``cur_prune_rate`` with ``l1_prune`` and
          its parameters are reset from ``global_init_model`` only when the
          rate was just raised (or the threshold has already been reached)
          and ``cur_prune_rate`` exceeds ``prune_rate``.
        * In every other case the existing zeros of the global model are
          re-registered as pruned, layer by layer.

        One entry is appended to ``self.prune_rates`` per call, and
        ``self.model`` always ends up referring to ``self.global_model``.
    """
    print(f"\n----------Client:{self.idx} Update---------------------")
    print(f'----------User Class ids: {self.class_idxs}------------')

    print(f"Evaluating Global model ")
    metrics = self.eval(self.global_model)
    accuracy = metrics['Accuracy'][0]
    print(f'Global model accuracy: {accuracy}')

    _, num_zeros, num_global = get_prune_summary(model=self.global_model,
                                                 name='weight')
    prune_rate = num_zeros / num_global
    print('Global model prune percentage: {}'.format(prune_rate))

    below_threshold = self.cur_prune_rate < self.args.prune_threshold
    # Accuracy is only consulted while the client is still ramping up.
    accurate_enough = below_threshold and bool(accuracy > self.eita)

    if accurate_enough:
        self.cur_prune_rate = min(self.cur_prune_rate + self.args.prune_step,
                                  self.args.prune_threshold)

    may_prune_further = accurate_enough or not below_threshold
    if may_prune_further and self.cur_prune_rate > prune_rate:
        l1_prune(model=self.global_model,
                 amount=self.cur_prune_rate,
                 name='weight',
                 verbose=self.args.prune_verbose)
        # reinitialize model with init_params
        source_params = dict(self.global_init_model.named_parameters())
        for name, param in self.global_model.named_parameters():
            param.data.copy_(source_params[name].data)
        self.prune_rates.append(self.cur_prune_rate)
    else:
        # reprune by the downloaded global-model(important)
        # REVIEW: Rather than pruning each layer by orig_global_pruned_%,
        # pruned each layer by its' orig_pruned_%
        for module, name in get_prune_params(self.global_model):
            # Exact integer count: a float32 round-trip is inexact above
            # 2**24 entries and could even exceed the tensor size.
            layer_zeros = int(
                torch.eq(getattr(module, name), 0.00).sum().item())
            prune.l1_unstructured(module, name, amount=layer_zeros)
        self.prune_rates.append(prune_rate)

    if below_threshold:
        if accurate_enough:
            self.eita = self.eita_hat
        else:
            self.eita *= self.alpha

    self.model = self.global_model

    print(f"\nTraining local model")
    self.train(self.elapsed_comm_rounds)

    print(f"\nEvaluating Trained Model")
    metrics = self.eval(self.model)
    print(f'Trained model accuracy: {metrics["Accuracy"][0]}')

    # Kept as separate calls so the number of wandb.log invocations per
    # update is unchanged.
    wandb.log({f"{self.idx}_cur_prune_rate": self.cur_prune_rate})
    wandb.log({f"{self.idx}_eita": self.eita})
    wandb.log({f"{self.idx}_percent_pruned": self.prune_rates[-1]})

    for key, thing in metrics.items():
        # List-valued metrics are logged by their first element.
        if isinstance(thing, list):
            wandb.log({f"{self.idx}_{key}": thing[0]})
        else:
            wandb.log({f"{self.idx}_{key}": thing})

    self.save(self.model)
    self.elapsed_comm_rounds += 1