def round(self, round_num=None):
    """Run one round of federated learning and return the round's accuracy.

    Selects a sample of clients, configures them, runs each client's ``run``
    method in its own thread and waits for all of them, collects their
    reports, aggregates the reported weights into ``self.model``, optionally
    saves the reports, saves the global model, and finally evaluates it.

    Args:
        round_num: Identifier of this round, forwarded to ``save_reports``.
            When omitted, a per-instance counter is incremented, so
            successive calls use 1, 2, 3, ...

    Returns:
        The average accuracy from the client reports when
        ``self.config.clients.do_test`` is set, otherwise the accuracy of the
        updated global model on the server's test set.
    """
    import fl_model  # pylint: disable=import-error

    # Previously the bare name ``round`` was passed to save_reports; inside a
    # method body that resolves to the builtin ``round`` function (class
    # attributes are not in scope), not to a round number.
    if round_num is None:
        round_num = getattr(self, '_round_num', 0) + 1
    self._round_num = round_num

    # Select clients to participate in the round
    sample_clients = self.selection()

    # Configure sample clients
    self.configuration(sample_clients)

    # Run clients concurrently, one thread per client, and wait for all
    threads = [Thread(target=client.run) for client in sample_clients]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Receive client updates
    reports = self.reporting(sample_clients)

    # Perform weight aggregation
    logging.info('Aggregating updates')
    updated_weights = self.aggregation(reports)

    # Load updated weights into the global model
    fl_model.load_weights(self.model, updated_weights)

    # Save this round's reports (only when a reports path is configured)
    if self.config.paths.reports:
        self.save_reports(round_num, reports)

    # Save updated global model
    self.save_model(self.model, self.config.paths.model)

    # Test global model accuracy
    if self.config.clients.do_test:
        # Get average accuracy from client reports
        accuracy = self.accuracy_averaging(reports)
    else:
        # Test updated model on server
        testset = self.loader.get_testset()
        batch_size = self.config.fl.batch_size
        testloader = fl_model.get_testloader(testset, batch_size)
        accuracy = fl_model.test(self.model, testloader)

    logging.info('Average accuracy: %.2f%%\n', 100 * accuracy)
    return accuracy
def train(self):
    """Train the local model on this client's data and prepare its report.

    Builds a training loader from ``self.trainset`` with ``self.batch_size``
    and runs ``fl_model.train`` with ``self.optimizer`` and ``self.epochs``.
    It then stores a new ``Report`` carrying the extracted model weights on
    ``self.report``. If ``self.do_test`` is truthy, the trained model is also
    evaluated on ``self.testset`` (test batch size 1000), and the result is
    stored as ``self.report.accuracy``.
    """
    import fl_model  # pylint: disable=import-error

    logging.info('Training on client #{}'.format(self.client_id))

    # Local training pass over this client's training set.
    fl_model.train(
        self.model,
        fl_model.get_trainloader(self.trainset, self.batch_size),
        self.optimizer,
        self.epochs,
    )

    # Snapshot the trained parameters, then wrap them in a report for the server.
    trained_weights = fl_model.extract_weights(self.model)
    self.report = Report(self)
    self.report.weights = trained_weights

    # Optional local evaluation; skipped entirely when testing is disabled.
    if not self.do_test:
        return
    eval_loader = fl_model.get_testloader(self.testset, 1000)
    self.report.accuracy = fl_model.test(self.model, eval_loader)
def round(self, round_num=None):
    """Run one DQN-driven round of federated learning and return its accuracy.

    ``self.global_weights`` holds ``(client, flattened_weights)`` pairs and
    is indexed by the values returned from ``self.selection()``; those
    indices are also recorded as the DQN actions.

    The round proceeds as follows:
    - The selected clients run in parallel threads.
    - Their reports are aggregated into ``self.model``.
    - Their entries in ``self.global_weights`` are refreshed with the newly
      reported weights.
    - The concatenation of every client's weights, before and after the
      round, serves as the DQN state ``s`` and next state ``s_``.
    - One transition per selected index is stored.
    - The DQN learns once its memory counter exceeds ``MEMORY_CAPACITY``.

    Args:
        round_num: Identifier of this round, forwarded to ``save_reports``.
            When omitted, a per-instance counter is incremented, so
            successive calls use 1, 2, 3, ...

    Returns:
        The average accuracy from the client reports when
        ``self.config.clients.do_test`` is set, otherwise the accuracy of the
        updated global model on the server's test set.
    """
    import fl_model  # pylint: disable=import-error

    # Previously the bare name ``round`` was passed to save_reports; inside a
    # method body that resolves to the builtin ``round`` function (class
    # attributes are not in scope), not to a round number.
    if round_num is None:
        round_num = getattr(self, '_round_num', 0) + 1
    self._round_num = round_num

    # Select clients: indices into self.global_weights (also the DQN actions)
    sample_client_index = self.selection()
    sample_client = [self.global_weights[i][0] for i in sample_client_index]

    # DQN state before the round: all clients' weights, flattened to 1-D
    s = trans_torch([weight for _, weight in self.global_weights]).reshape(-1)

    # Configure sample clients
    self.configuration(sample_client)

    # Run clients concurrently, one thread per client, and wait for all
    threads = [Thread(target=client.run) for client in sample_client]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Receive client updates
    reports = self.reporting(sample_client)

    # Perform weight aggregation
    logging.info('Aggregating updates')
    updated_weights = self.aggregation(reports)

    # Load updated weights into the global model
    fl_model.load_weights(self.model, updated_weights)

    # Copy each selected client's reported weights (flattened to 1-D) back
    # into global_weights. Assumes reporting() returns reports in the same
    # order as sample_client -- TODO confirm.
    for pos, index in enumerate(sample_client_index):
        self.global_weights[index] = (
            sample_client[pos], self.flatten_weights(reports[pos].weights))

    # Save this round's reports (only when a reports path is configured)
    if self.config.paths.reports:
        self.save_reports(round_num, reports)

    # Save updated global model
    self.save_model(self.model, self.config.paths.model)

    # Test global model accuracy
    if self.config.clients.do_test:
        # Get average accuracy from client reports
        accuracy = self.accuracy_averaging(reports)
    else:
        # Test updated model on server
        testset = self.loader.get_testset()
        batch_size = self.config.fl.batch_size
        testloader = fl_model.get_testloader(testset, batch_size)
        accuracy = fl_model.test(self.model, testloader)

    # DQN feedback and training: collect (s, a, r, s_), where the state is
    # the weights of all clients after this round.
    s_ = trans_torch([weight for _, weight in self.global_weights]).reshape(-1)
    print("action:", sample_client_index)

    # Reward is (2 ** (accuracy - target) - 1) * 10: zero at the target,
    # negative below it, positive above it. The base 2 could be made a
    # configurable parameter.
    r = (pow(2, (accuracy - self.config.fl.target_accuracy)) - 1) * 10
    print("reward:", r)

    # Each selected client index is stored as its own action, all sharing
    # this round's reward.
    for action in sample_client_index:
        self.dqn.store_transition(s, action, r, s_)

    # Start learning once the memory counter exceeds MEMORY_CAPACITY
    if self.dqn.memory_counter > MEMORY_CAPACITY:
        self.dqn.learn()

    logging.info('Average accuracy: %.2f%%\n', 100 * accuracy)
    return accuracy