def accuracy_fed_avg(self, reports, factor=8):
    """Aggregate client updates, weighting each client by its test accuracy.

    Each client's weight is ``accuracy ** factor`` divided by the sum of
    that quantity over all clients, so a large ``factor`` strongly favours
    the most accurate clients.

    Args:
        reports: Client reports; each must expose ``weights`` (consumed by
            ``extract_client_updates``) and ``accuracy``.
        factor: Exponent applied to every accuracy before normalising.
            Defaults to 8, the value previously hard-coded here.

    Returns:
        List of ``(name, tensor)`` pairs: the current global weights plus
        the accuracy-weighted average of the client updates.

    Raises:
        ValueError: If ``reports`` is empty.
    """
    import fl_model  # pylint: disable=import-error

    if not reports:
        raise ValueError('No client reports to aggregate')

    # Extract updates (client weights minus global weights) from reports
    updates = self.extract_client_updates(reports)

    # Determine weighting based on accuracies. Force float so that integer
    # accuracies raised to a large power cannot overflow a fixed-width dtype.
    scaled = np.array([report.accuracy for report in reports],
                      dtype=float) ** factor
    total = scaled.sum()
    if total > 0:
        client_weights = scaled / total
    else:
        # Every accuracy is zero: normalising would compute 0/0 and fill the
        # model with NaNs, so fall back to a plain uniform average.
        client_weights = np.full(len(reports), 1.0 / len(reports))

    # Perform weighted averaging, layer by layer
    avg_update = [
        torch.zeros(x.size())  # pylint: disable=no-member
        for _, x in updates[0]
    ]
    for client_weight, update in zip(client_weights, updates):
        for j, (_, delta) in enumerate(update):
            # Weight each client's update by its normalised accuracy score
            avg_update[j] += delta * client_weight

    # Add the averaged update to the current global weights
    baseline_weights = fl_model.extract_weights(self.model)
    return [(name, weight + avg_update[i])
            for i, (name, weight) in enumerate(baseline_weights)]
def federated_averaging(self, reports):
    """Aggregate client updates as a sample-count-weighted mean (FedAvg).

    Every client's update (its weights minus the current global weights) is
    scaled by ``report.num_samples / total_samples`` and summed. The result
    is added to the current global weights.

    Args:
        reports: Client reports exposing ``weights`` and ``num_samples``.

    Returns:
        List of ``(name, tensor)`` pairs holding the aggregated weights.
    """
    import fl_model  # pylint: disable=import-error

    client_updates = self.extract_client_updates(reports)
    total_samples = sum(report.num_samples for report in reports)

    # One zero accumulator per layer, shaped like the first client's update
    accumulated = [torch.zeros(tensor.size())  # pylint: disable=no-member
                   for _, tensor in client_updates[0]]

    for report, update in zip(reports, client_updates):
        for layer_idx, (_, delta) in enumerate(update):
            # Clients with more training samples contribute proportionally more
            accumulated[layer_idx] += delta * (report.num_samples / total_samples)

    global_weights = fl_model.extract_weights(self.model)
    return [(name, tensor + accumulated[layer_idx])
            for layer_idx, (name, tensor) in enumerate(global_weights)]
def magnetude_fed_avg(self, reports):
    """Aggregate client updates, weighting each client by its update magnitude.

    A client's magnitude is the L2 norm of its whole update (all layers taken
    together), computed as the square root of the summed squared per-layer
    norms. Its weight is that magnitude divided by the sum of magnitudes over
    all clients.

    Args:
        reports: Client reports exposing ``weights``.

    Returns:
        List of ``(name, tensor)`` pairs: the current global weights plus
        the magnitude-weighted average of the client updates.
    """
    import fl_model  # pylint: disable=import-error

    # Extract updates (client weights minus global weights) from reports
    updates = self.extract_client_updates(reports)

    # Whole-update L2 norm per client. float() pulls the scalar off the
    # tensor's device, so this also works for GPU tensors (np.sqrt on a CUDA
    # tensor raises).
    magnitudes = [
        sum(float(delta.norm()) ** 2 for _, delta in update) ** 0.5
        for update in updates
    ]
    # Hoisted: this was previously recomputed for every client and layer.
    total = sum(magnitudes)

    avg_update = [torch.zeros(x.size())  # pylint: disable=no-member
                  for _, x in updates[0]]
    # If every update has zero norm then every delta is zero and the average
    # is zero as well; skipping avoids 0/0 turning the model into NaNs.
    if total > 0:
        for magnitude, update in zip(magnitudes, updates):
            share = magnitude / total
            for j, (_, delta) in enumerate(update):
                # Weight each update by its share of the total magnitude
                avg_update[j] += delta * share

    # Add the averaged update to the current global weights
    baseline_weights = fl_model.extract_weights(self.model)
    return [(name, weight + avg_update[i])
            for i, (name, weight) in enumerate(baseline_weights)]
def extract_client_updates(self, reports):
    """Compute each client's update relative to the current global model.

    Args:
        reports: Client reports whose ``weights`` are ``(name, tensor)``
            pairs, expected in the same order as
            ``fl_model.extract_weights(self.model)``.

    Returns:
        One list per report, each containing
        ``(name, client_tensor - global_tensor)`` pairs in layer order.

    Raises:
        ValueError: If a report's layer count or layer names do not match
            the global model.
    """
    import fl_model  # pylint: disable=import-error

    # Extract baseline (global) model weights
    baseline_weights = fl_model.extract_weights(self.model)

    updates = []
    for report in reports:
        client_weights = list(report.weights)
        # Updates are matched to layers by position, so a length mismatch
        # would otherwise be silently truncated or raise IndexError.
        if len(client_weights) != len(baseline_weights):
            raise ValueError(
                'Client reported {} layers, global model has {}'.format(
                    len(client_weights), len(baseline_weights)))

        update = []
        for (name, weight), (bl_name, baseline) in zip(client_weights,
                                                        baseline_weights):
            # Ensure the correct layer is being updated. This is an explicit
            # check rather than `assert`, which is stripped under `python -O`.
            if name != bl_name:
                raise ValueError(
                    'Layer mismatch: client {!r} vs global {!r}'.format(
                        name, bl_name))
            update.append((name, weight - baseline))
        updates.append(update)

    return updates
def save_reports(self, round, reports):
    """Record flattened client and global weights for one training round.

    If ``reports`` is non-empty, stores a list of
    ``(client_id, flattened_weights)`` tuples under the key ``'round<N>'``.
    Always stores the flattened current global weights under ``'w<N>'``.
    """
    import fl_model  # pylint: disable=import-error

    if reports:
        client_entries = []
        for report in reports:
            flattened = self.flatten_weights(report.weights)
            client_entries.append((report.client_id, flattened))
        self.saved_reports['round{}'.format(round)] = client_entries

    # Snapshot of the global model weights for this round
    global_weights = fl_model.extract_weights(self.model)
    self.saved_reports['w{}'.format(round)] = self.flatten_weights(
        global_weights)
def profiling(self):
    """Profile every client by the direction in which its training moves the model.

    All clients are configured and trained concurrently, one thread each.
    Their reported weights are then used in two ways:

    * aggregated with ``self.aggregation`` and loaded as the new global model;
    * turned into unit-length "director" vectors ``w_client - w0``, where
      ``w0`` is the flattened global model before profiling. A client whose
      weights did not change keeps a zero vector.

    Side effects: sets ``self.w_previous`` to ``w0``, ``self.punishment`` to
    one zero per client, and ``self.profiles`` to ``(client, director)``
    pairs.

    Returns:
        ``self.profiles``.
    """
    import fl_model  # pylint: disable=import-error

    # Use all clients for profiling
    clients = self.clients
    self.configuration(clients)

    # Train all clients concurrently to generate profile weights
    threads = [Thread(target=client.train) for client in clients]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Receive client reports and flatten their weights
    reports = self.reporting(clients)
    weights = [self.flatten_weights(report.weights) for report in reports]

    # Initial model weights double as the first "previous" model
    w0 = self.flatten_weights(fl_model.extract_weights(self.model))
    self.w_previous = w0.copy()

    # Update the global model using the results of profiling
    logging.info('Aggregating updates')
    updated_weights = self.aggregation(reports)
    fl_model.load_weights(self.model, updated_weights)

    # Normalise each client's training direction to unit length
    directors = []
    for w in weights:
        direction = w - w0
        norm = np.sqrt(np.dot(direction, direction))
        # A client whose weights did not change has no direction. Keep its
        # zero vector rather than dividing by zero, whose NaNs would poison
        # every later selection score.
        directors.append(direction / norm if norm > 0 else direction)

    # Initialise punishment factors and client profiles
    self.punishment = [0] * len(clients)
    self.profiles = [(client, directors[i]) for i, client in enumerate(clients)]
    return self.profiles
def selection(self):
    """Pick the clients whose profiles best match the model's most recent move.

    The score of each client is the dot product of its unit "director" (from
    ``self.profiles``) with the unit direction the global model moved since
    ``self.w_previous``, scaled by ``0.9 ** punishment``. The
    ``self.config.clients.per_round`` highest-scoring clients are returned.
    Ties go to the lower client index, and no client is chosen twice.

    Side effects: sets ``self.w_previous`` to the current flattened weights.
    Resets ``self.punishment``: selected clients get their count
    incremented, all other clients are reset to 0.

    Returns:
        List of selected client objects, highest score first.
    """
    import fl_model  # pylint: disable=import-error

    clients = self.clients
    clients_per_round = self.config.clients.per_round

    # Extract directors from profiles
    directors = [d for _, d in self.profiles]

    # Direction the global model moved since the last selection
    w_current = self.flatten_weights(fl_model.extract_weights(self.model))
    model_direction = w_current - self.w_previous
    norm = np.sqrt(np.dot(model_direction, model_direction))
    if norm > 0:
        model_direction = model_direction / norm
    # Otherwise the model did not move: keep the zero vector so every score
    # is 0. Dividing by zero gives NaN scores, and the old selection loop then
    # picked the same client repeatedly.
    self.w_previous = w_current

    # Generate client director scores (closer direction is better)
    scores = [np.dot(director, model_direction) for director in directors]

    # Apply punishment for repeatedly selected clients.
    # NOTE(review): for a negative score, multiplying by 0.9**p moves it
    # towards 0, i.e. it *raises* a repeatedly selected client's score —
    # confirm this is intended.
    p = self.punishment
    scores = [x * (0.9) ** p[i] for i, x in enumerate(scores)]

    # Select clients with the highest scores. Python's sort stays stable
    # with reverse=True, so equal scores keep ascending index order (as the
    # previous max/index loop did). Slicing never yields duplicates, even if
    # clients_per_round exceeds the number of clients.
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    sample_clients_index = ranked[:clients_per_round]
    sample_clients = [clients[i] for i in sample_clients_index]

    # Update punishment factors
    selected = set(sample_clients_index)
    self.punishment = [
        p[i] + 1 if i in selected else 0 for i in range(len(clients))
    ]
    return sample_clients
def train(self):
    """Train this client's model locally and store the outcome in ``self.report``.

    Runs ``fl_model.train`` over the client's training set and stores the
    resulting model weights in a fresh ``Report``. If ``self.do_test`` is set,
    also evaluates on the client's test set (batch size 1000) and records the
    result as ``self.report.accuracy``.
    """
    import fl_model  # pylint: disable=import-error

    logging.info('Training on client #{}'.format(self.client_id))

    # Local training pass
    train_loader = fl_model.get_trainloader(self.trainset, self.batch_size)
    fl_model.train(self.model, train_loader, self.optimizer, self.epochs)

    # Package the trained weights for the server
    trained_weights = fl_model.extract_weights(self.model)
    self.report = Report(self)
    self.report.weights = trained_weights

    if not self.do_test:
        return

    # Optional local evaluation
    test_loader = fl_model.get_testloader(self.testset, 1000)
    self.report.accuracy = fl_model.test(self.model, test_loader)